diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 50014ef..252674d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/example.py b/example.py index dea19e6..82fd81c 100644 --- a/example.py +++ b/example.py @@ -75,10 +75,10 @@ "parseDOM": [{"tag": "a[href]"}], }, "em": { - "parseDOM": [{"tag": "i"}, {"tag": "em"}, {"style": "font-style=italic"}] + "parseDOM": [{"tag": "i"}, {"tag": "em"}, {"style": "font-style=italic"}], }, "strong": { - "parseDOM": [{"tag": "strong"}, {"tag": "b"}, {"style": "font-weight"}] + "parseDOM": [{"tag": "strong"}, {"tag": "b"}, {"style": "font-weight"}], }, "code": {"parseDOM": [{"tag": "code"}]}, }, diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 3bf6217..31d4b84 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -3,14 +3,11 @@ from typing import ( TYPE_CHECKING, ClassVar, - Dict, - List, Literal, NamedTuple, NoReturn, Optional, TypedDict, - Union, cast, ) @@ -32,11 +29,9 @@ def __init__(self, type: "NodeType", next: "ContentMatch") -> None: class WrapCacheEntry: target: "NodeType" - computed: Optional[List["NodeType"]] + computed: list["NodeType"] | None - def __init__( - self, target: "NodeType", computed: Optional[List["NodeType"]] - ) -> None: + def __init__(self, target: "NodeType", computed: list["NodeType"] | None) -> None: self.target = target self.computed = computed @@ -57,8 +52,8 @@ class ContentMatch: empty: ClassVar["ContentMatch"] valid_end: bool - next: List[MatchEdge] - wrap_cache: List[WrapCacheEntry] + next: list[MatchEdge] + wrap_cache: list[WrapCacheEntry] def __init__(self, valid_end: bool) -> None: self.valid_end = valid_end @@ -66,7 +61,7 @@ def __init__(self, valid_end: bool) -> None: self.wrap_cache = [] @classmethod - def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch": + def parse(cls, string: str, node_types: dict[str, "NodeType"]) -> "ContentMatch": stream = TokenStream(string, node_types) if stream.next() is None: return ContentMatch.empty @@ -84,11 +79,14 @@ def match_type(self, type: "NodeType") -> Optional["ContentMatch"]: return None def match_fragment( - self, frag: Fragment, start: int = 0, end: Optional[int] = None + self, + frag: Fragment, + start: int = 0, + end: int | None = None, ) -> Optional["ContentMatch"]: if end is None: end = frag.child_count - cur: Optional["ContentMatch"] = self + cur: "ContentMatch" | None = self i = start while cur and i < end: cur = cur.match_type(frag.child(i).type) @@ -115,11 +113,14 @@ def compatible(self, other: "ContentMatch") -> bool: return False def fill_before( - self, after: Fragment, to_end: bool = False, start_index: int = 0 - ) -> Optional[Fragment]: + self, + after: Fragment, + to_end: bool = False, + start_index: int = 0, + ) -> Fragment | None: seen = [self] - def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: + def search(match: ContentMatch, types: list["NodeType"]) -> Fragment | None: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): @@ -138,7 +139,7 @@ def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: return search(self, []) - def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: + def find_wrapping(self, target: "NodeType") -> list["NodeType"] | None: for entry in self.wrap_cache: if entry.target.name == target.name: return entry.computed @@ -146,9 +147,9 @@ def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: self.wrap_cache.append(WrapCacheEntry(target, computed)) return computed - def compute_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: + def compute_wrapping(self, target: "NodeType") -> list["NodeType"] | None: seen = {} - active: List[Active] = [{"match": self, "type": None, "via": None}] + active: list[Active] = [{"match": self, "type": None, "via": None}] while len(active): current = active.pop(0) match = current["match"] @@ -181,7 +182,8 @@ def edge_count(self) -> int: def edge(self, n: int) -> MatchEdge: if n >= len(self.next): - raise ValueError(f"There's no {n}th edge in this content match") + msg = f"There's no {n}th edge in this content match" + raise ValueError(msg) return self.next[n] def __str__(self) -> str: @@ -217,23 +219,23 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: - inline: Optional[bool] - tokens: List[str] + inline: bool | None + tokens: list[str] - def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: + def __init__(self, string: str, node_types: dict[str, "NodeType"]) -> None: self.string = string self.node_types = node_types self.inline = None self.pos = 0 self.tokens = [i for i in TOKEN_REGEX.findall(string) if i.strip()] - def next(self) -> Optional[str]: + def next(self) -> str | None: try: return self.tokens[self.pos] except IndexError: return None - def eat(self, tok: str) -> Union[int, bool]: + def eat(self, tok: str) -> int | bool: if self.next() == tok: pos = self.pos self.pos += 1 @@ -242,17 +244,18 @@ def eat(self, tok: str) -> Union[int, bool]: return False def err(self, str: str) -> NoReturn: - raise SyntaxError(f'{str} (in content expression) "{self.string}"') + msg = f'{str} (in content expression) "{self.string}"' + raise SyntaxError(msg) class ChoiceExpr(TypedDict): type: Literal["choice"] - exprs: List["Expr"] + exprs: list["Expr"] class SeqExpr(TypedDict): type: Literal["seq"] - exprs: List["Expr"] + exprs: list["Expr"] class PlusExpr(TypedDict): @@ -282,7 +285,7 @@ class NameExpr(TypedDict): value: "NodeType" -Expr = Union[ChoiceExpr, SeqExpr, PlusExpr, StarExpr, OptExpr, RangeExpr, NameExpr] +Expr = ChoiceExpr | SeqExpr | PlusExpr | StarExpr | OptExpr | RangeExpr | NameExpr def parse_expr(stream: TokenStream) -> Expr: @@ -341,16 +344,13 @@ def parse_expr_range(stream: TokenStream, expr: Expr) -> Expr: min_ = parse_num(stream) max_ = min_ if stream.eat(","): - if stream.next() != "}": - max_ = parse_num(stream) - else: - max_ = -1 + max_ = parse_num(stream) if stream.next() != "}" else -1 if not stream.eat("}"): stream.err("Unclosed braced range") return {"type": "range", "min": min_, "max": max_, "expr": expr} -def resolve_name(stream: TokenStream, name: str) -> List["NodeType"]: +def resolve_name(stream: TokenStream, name: str) -> list["NodeType"]: types = stream.node_types type = types.get(name) if type: @@ -395,13 +395,13 @@ def iteratee(type: "NodeType") -> Expr: class Edge(TypedDict): term: Optional["NodeType"] - to: Optional[int] + to: int | None def nfa( expr: Expr, -) -> List[List[Edge]]: - nfa_: List[List[Edge]] = [[]] +) -> list[list[Edge]]: + nfa_: list[list[Edge]] = [[]] def node() -> int: nonlocal nfa_ @@ -409,25 +409,27 @@ def node() -> int: return len(nfa_) - 1 def edge( - from_: int, to: Optional[int] = None, term: Optional["NodeType"] = None + from_: int, + to: int | None = None, + term: Optional["NodeType"] = None, ) -> Edge: nonlocal nfa_ edge: Edge = {"term": term, "to": to} nfa_[from_].append(edge) return edge - def connect(edges: List[Edge], to: int) -> None: + def connect(edges: list[Edge], to: int) -> None: for edge in edges: edge["to"] = to - def compile(expr: Expr, from_: int) -> List[Edge]: + def compile(expr: Expr, from_: int) -> list[Edge]: if expr["type"] == "choice": return list( reduce( lambda out, expr: [*out, *compile(expr, from_)], expr["exprs"], - cast(List[Edge], []), - ) + cast(list[Edge], []), + ), ) elif expr["type"] == "seq": i = 0 @@ -452,14 +454,14 @@ def compile(expr: Expr, from_: int) -> List[Edge]: return [edge(from_), *compile(expr["expr"], from_)] elif expr["type"] == "range": cur = from_ - for i in range(expr["min"]): + for _i in range(expr["min"]): next = node() connect(compile(expr["expr"], cur), next) cur = next if expr["max"] == -1: connect(compile(expr["expr"], cur), cur) else: - for i in range(expr["min"], expr["max"]): + for _i in range(expr["min"], expr["max"]): next = node() edge(cur, next) connect(compile(expr["expr"], cur), next) @@ -477,9 +479,9 @@ def cmp(a: int, b: int) -> int: def null_from( - nfa: List[List[Edge]], + nfa: list[list[Edge]], node: int, -) -> List[int]: +) -> list[int]: result = [] def scan(n: int) -> None: @@ -499,21 +501,21 @@ def scan(n: int) -> None: class DFAState(NamedTuple): state: "NodeType" - next: List[int] + next: list[int] -def dfa(nfa: List[List[Edge]]) -> ContentMatch: +def dfa(nfa: list[list[Edge]]) -> ContentMatch: labeled = {} - def explore(states: List[int]) -> ContentMatch: + def explore(states: list[int]) -> ContentMatch: nonlocal labeled - out: List[DFAState] = [] + out: list[DFAState] = [] for node in states: for item in nfa[node]: term, to = item.get("term"), item.get("to") if not term: continue - set: Optional[List[int]] = None + set: list[int] | None = None for t in out: if t[0] == term: set = t[1] @@ -530,7 +532,7 @@ def explore(states: List[int]) -> ContentMatch: states = out[i][1] find_by_key = ",".join(str(s) for s in states) state.next.append( - MatchEdge(out[i][0], labeled.get(find_by_key) or explore(states)) + MatchEdge(out[i][0], labeled.get(find_by_key) or explore(states)), ) return state @@ -555,6 +557,6 @@ def check_for_dead_ends(match: ContentMatch, stream: TokenStream) -> None: if dead: stream.err( f'Only non-generatable nodes ({", ".join(nodes)}) in a required ' - "position (see https://prosemirror.net/docs/guide/#generatable)" + "position (see https://prosemirror.net/docs/guide/#generatable)", ) i += 1 diff --git a/prosemirror/model/diff.py b/prosemirror/model/diff.py index e4f7c1a..2dbef41 100644 --- a/prosemirror/model/diff.py +++ b/prosemirror/model/diff.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import TYPE_CHECKING, TypedDict from prosemirror.utils import text_length @@ -13,7 +13,7 @@ class Diff(TypedDict): b: int -def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: +def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> int | None: i = 0 while True: if a.child_count == i or b.child_count == i: @@ -36,7 +36,9 @@ def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: ( index_a for ((index_a, char_a), (_, char_b)) in zip( - enumerate(child_a.text), enumerate(child_b.text) + enumerate(child_a.text), + enumerate(child_b.text), + strict=True, ) if char_a != char_b ), @@ -52,9 +54,7 @@ def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: i += 1 -def find_diff_end( - a: "Fragment", b: "Fragment", pos_a: int, pos_b: int -) -> Optional[Diff]: +def find_diff_end(a: "Fragment", b: "Fragment", pos_a: int, pos_b: int) -> Diff | None: i_a, i_b = a.child_count, b.child_count while True: if i_a == 0 or i_b == 0: @@ -94,7 +94,10 @@ def find_diff_end( if child_a.content.size or child_b.content.size: inner = find_diff_end( - child_a.content, child_b.content, pos_a - 1, pos_b - 1 + child_a.content, + child_b.content, + pos_a - 1, + pos_b - 1, ) if inner: return inner diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 3e2b72e..ffe49a1 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -1,18 +1,14 @@ +from collections.abc import Callable, Iterable, Sequence from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Dict, - Iterable, - List, Optional, - Sequence, Union, cast, ) -from prosemirror.utils import JSONList, text_length +from prosemirror.utils import JSON, JSONList, text_length if TYPE_CHECKING: from prosemirror.model.schema import Schema @@ -21,16 +17,16 @@ from .node import Node, TextNode -def retIndex(index: int, offset: int) -> Dict[str, int]: +def ret_index(index: int, offset: int) -> dict[str, int]: return {"index": index, "offset": offset} class Fragment: empty: ClassVar["Fragment"] - content: List["Node"] + content: list["Node"] size: int - def __init__(self, content: List["Node"], size: Optional[int] = None) -> None: + def __init__(self, content: list["Node"], size: int | None = None) -> None: self.content = content self.size = size if size is not None else sum(c.node_size for c in content) @@ -38,7 +34,7 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + f: Callable[["Node", int, Optional["Node"], int], bool | None], node_start: int = 0, parent: Optional["Node"] = None, ) -> None: @@ -63,7 +59,8 @@ def nodes_between( i += 1 def descendants( - self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + self, + f: Callable[["Node", int, Optional["Node"], int], bool | None], ) -> None: self.nodes_between(0, self.size, f) @@ -72,13 +69,16 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: text = [] separated = True def iteratee( - node: "Node", pos: int, _parent: Optional["Node"], _to: int + node: "Node", + pos: int, + _parent: Optional["Node"], + _to: int, ) -> None: nonlocal text nonlocal separated @@ -110,7 +110,8 @@ def append(self, other: "Fragment") -> "Fragment": self.content.copy(), 0, ) - assert last is not None and first is not None + assert last is not None + assert first is not None if pm_node.is_text(last) and last.same_markup(first): assert isinstance(first, pm_node.TextNode) content[len(content) - 1] = last.with_text(last.text + first.text) @@ -120,12 +121,12 @@ def append(self, other: "Fragment") -> "Fragment": i += 1 return Fragment(content, self.size + other.size) - def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": + def cut(self, from_: int, to: int | None = None) -> "Fragment": if to is None: to = self.size if from_ == 0 and to == self.size: return self - result: List["Node"] = [] + result: list["Node"] = [] size = 0 if to <= from_: return Fragment(result, size) @@ -137,7 +138,8 @@ def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": if pos < from_ or end > to: if pm_node.is_text(child): child = child.cut( - max(0, from_ - pos), min(text_length(child.text), to - pos) + max(0, from_ - pos), + min(text_length(child.text), to - pos), ) else: child = child.cut( @@ -150,7 +152,7 @@ def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": i += 1 return Fragment(result, size) - def cut_by_index(self, from_: int, to: Optional[int] = None) -> "Fragment": + def cut_by_index(self, from_: int, to: int | None = None) -> "Fragment": if from_ == to: return Fragment.empty if from_ == 0 and to == len(self.content): @@ -175,7 +177,7 @@ def add_to_end(self, node: "Node") -> "Fragment": def eq(self, other: "Fragment") -> bool: if len(self.content) != len(other.content): return False - return all(a.eq(b) for (a, b) in zip(self.content, other.content)) + return all(a.eq(b) for (a, b) in zip(self.content, other.content, strict=True)) @property def first_child(self) -> Optional["Node"]: @@ -207,7 +209,7 @@ def for_each(self, f: Callable[["Node", int, int], Any]) -> None: p += child.node_size i += 1 - def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: + def find_diff_start(self, other: "Fragment", pos: int = 0) -> int | None: from .diff import find_diff_start return find_diff_start(self, other, pos) @@ -215,8 +217,8 @@ def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: def find_diff_end( self, other: "Fragment", - pos: Optional[int] = None, - other_pos: Optional[int] = None, + pos: int | None = None, + other_pos: int | None = None, ) -> Optional["Diff"]: from .diff import find_diff_end @@ -226,13 +228,14 @@ def find_diff_end( other_pos = other.size return find_diff_end(self, other, pos, other_pos) - def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: + def find_index(self, pos: int, round: int = -1) -> dict[str, int]: if pos == 0: - return retIndex(0, pos) + return ret_index(0, pos) if pos == self.size: - return retIndex(len(self.content), pos) + return ret_index(len(self.content), pos) if pos > self.size or pos < 0: - raise ValueError(f"Position {pos} outside of fragment ({self})") + msg = f"Position {pos} outside of fragment ({self})" + raise ValueError(msg) i = 0 cur_pos = 0 while True: @@ -240,18 +243,18 @@ def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: end = cur_pos + cur.node_size if end >= pos: if end == pos or round > 0: - return retIndex(i + 1, end) - return retIndex(i, cur_pos) + return ret_index(i + 1, end) + return ret_index(i, cur_pos) i += 1 cur_pos = end - def to_json(self) -> Optional[JSONList]: + def to_json(self) -> JSONList | None: if self.content: return [item.to_json() for item in self.content] return None @classmethod - def from_json(cls, schema: "Schema[Any, Any]", value: Any) -> "Fragment": + def from_json(cls, schema: "Schema[Any, Any]", value: JSON) -> "Fragment": if not value: return cls.empty @@ -261,15 +264,16 @@ def from_json(cls, schema: "Schema[Any, Any]", value: Any) -> "Fragment": value = json.loads(value) if not isinstance(value, list): - raise ValueError("Invalid input for Fragment.from_json") + msg = "Invalid input for Fragment.from_json" + raise ValueError(msg) return cls([schema.node_from_json(item) for item in value]) @classmethod - def from_array(cls, array: List["Node"]) -> "Fragment": + def from_array(cls, array: list["Node"]) -> "Fragment": if not array: return cls.empty - joined: Optional[List["Node"]] = None + joined: list["Node"] | None = None size = 0 for i in range(len(array)): node = array[i] @@ -286,7 +290,8 @@ def from_array(cls, array: List["Node"]) -> "Fragment": @classmethod def from_( - cls, nodes: Union["Fragment", "Node", Sequence["Node"], None] + cls, + nodes: Union["Fragment", "Node", Sequence["Node"], None], ) -> "Fragment": if not nodes: return cls.empty @@ -296,7 +301,8 @@ def from_( return cls.from_array(list(nodes)) if hasattr(nodes, "attrs"): return cls([nodes], nodes.node_size) - raise ValueError(f"cannot convert {nodes!r} to a fragment") + msg = f"cannot convert {nodes!r} to a fragment" + raise ValueError(msg) def to_string_inner(self) -> str: return ", ".join([str(i) for i in self.content]) diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index e1760a1..994a2e9 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1,7 +1,8 @@ import itertools import re +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Literal, cast import lxml from lxml.cssselect import CSSSelector @@ -17,51 +18,51 @@ from .resolvedpos import ResolvedPos from .schema import MarkType, NodeType, Schema -WSType = Union[bool, Literal["full"], None] +WSType = bool | Literal["full"] | None @dataclass class DOMPosition: node: DOMNode offset: int - pos: Optional[int] = None + pos: int | None = None @dataclass(frozen=True) class ParseOptions: preserve_whitespace: WSType = None - find_positions: Optional[List[DOMPosition]] = None - from_: Optional[int] = None - to_: Optional[int] = None - top_node: Optional[Node] = None - top_match: Optional[ContentMatch] = None - context: Optional[ResolvedPos] = None - rule_from_node: Optional[Callable[[DOMNode], "ParseRule"]] = None - top_open: Optional[bool] = None + find_positions: list[DOMPosition] | None = None + from_: int | None = None + to_: int | None = None + top_node: Node | None = None + top_match: ContentMatch | None = None + context: ResolvedPos | None = None + rule_from_node: Callable[[DOMNode], "ParseRule"] | None = None + top_open: bool | None = None @dataclass class ParseRule: - tag: Optional[str] - namespace: Optional[str] - style: Optional[str] - priority: Optional[int] - consuming: Optional[bool] - context: Optional[str] - node: Optional[str] - mark: Optional[str] - clear_mark: Optional[Callable[[Mark], bool]] - ignore: Optional[bool] - close_parent: Optional[bool] - skip: Optional[bool] - attrs: Optional[Attrs] - get_attrs: Optional[Callable[[DOMNode], Union[Attrs, Literal[False], None]]] - content_element: Union[str, DOMNode, Callable[[DOMNode], DOMNode], None] - get_content: Optional[Callable[[DOMNode, Schema[Any, Any]], Fragment]] + tag: str | None + namespace: str | None + style: str | None + priority: int | None + consuming: bool | None + context: str | None + node: str | None + mark: str | None + clear_mark: Callable[[Mark], bool] | None + ignore: bool | None + close_parent: bool | None + skip: bool | None + attrs: Attrs | None + get_attrs: Callable[[DOMNode], Attrs | Literal[False] | None] | None + content_element: str | DOMNode | Callable[[DOMNode], DOMNode] | None + get_content: Callable[[DOMNode, Schema[Any, Any]], Fragment] | None preserve_whitespace: WSType @classmethod - def from_json(cls, data: Dict[str, Any]) -> "ParseRule": + def from_json(cls, data: dict[str, Any]) -> "ParseRule": return ParseRule( data.get("tag"), data.get("namespace"), @@ -84,14 +85,14 @@ def from_json(cls, data: Dict[str, Any]) -> "ParseRule": class DOMParser: - _tags: List[ParseRule] - _styles: List[ParseRule] + _tags: list[ParseRule] + _styles: list[ParseRule] _normalize_lists: bool schema: Schema[Any, Any] - rules: List[ParseRule] + rules: list[ParseRule] - def __init__(self, schema: Schema[Any, Any], rules: List[ParseRule]) -> None: + def __init__(self, schema: Schema[Any, Any], rules: list[ParseRule]) -> None: self.schema = schema self.rules = rules self._tags = [rule for rule in rules if rule.tag is not None] @@ -107,7 +108,9 @@ def __init__(self, schema: Schema[Any, Any], rules: List[ParseRule]) -> None: ]) def parse( - self, dom_: lxml.html.HtmlElement, options: Optional[ParseOptions] = None + self, + dom_: lxml.html.HtmlElement, + options: ParseOptions | None = None, ) -> Node: if options is None: options = ParseOptions() @@ -132,9 +135,7 @@ def parse( return cast(Node, context.finish()) - def parse_slice( - self, dom_: DOMNode, options: Optional[ParseOptions] = None - ) -> Slice: + def parse_slice(self, dom_: DOMNode, options: ParseOptions | None = None) -> Slice: if options is None: options = ParseOptions(preserve_whitespace=True) @@ -145,8 +146,11 @@ def parse_slice( return Slice.max_open(cast(Fragment, context.finish())) def match_tag( - self, dom_: DOMNode, context: "ParseContext", after: Optional[ParseRule] = None - ) -> Optional[ParseRule]: + self, + dom_: DOMNode, + context: "ParseContext", + after: ParseRule | None = None, + ) -> ParseRule | None: try: i = self._tags.index(after) + 1 if after is not None else 0 except ValueError: @@ -177,8 +181,8 @@ def match_style( prop: str, value: str, context: "ParseContext", - after: Optional[ParseRule] = None, - ) -> Optional[ParseRule]: + after: ParseRule | None = None, + ) -> ParseRule | None: i = self._styles.index(after) + 1 if after is not None else 0 for rule in self._styles[i:]: @@ -208,8 +212,8 @@ def match_style( return None @classmethod - def schema_rules(cls, schema: Schema[Any, Any]) -> List[ParseRule]: - result: List[ParseRule] = [] + def schema_rules(cls, schema: Schema[Any, Any]) -> list[ParseRule]: + result: list[ParseRule] = [] def insert(rule: ParseRule) -> None: priority = rule.priority if rule.priority is not None else 50 @@ -255,13 +259,14 @@ def insert(rule: ParseRule) -> None: def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": if "dom_parser" not in schema.cached: schema.cached["dom_parser"] = DOMParser( - schema, DOMParser.schema_rules(schema) + schema, + DOMParser.schema_rules(schema), ) return cast("DOMParser", schema.cached["dom_parser"]) -BLOCK_TAGS: Dict[str, bool] = { +BLOCK_TAGS: dict[str, bool] = { "address": True, "article": True, "aside": True, @@ -296,7 +301,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": "ul": True, } -IGNORE_TAGS: Dict[str, bool] = { +IGNORE_TAGS: dict[str, bool] = { "head": True, "noscript": True, "object": True, @@ -305,7 +310,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": "title": True, } -LIST_TAGS: Dict[str, bool] = {"ol": True, "ul": True} +LIST_TAGS: dict[str, bool] = {"ol": True, "ul": True} OPT_PRESERVE_WS = 1 @@ -314,7 +319,9 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": def ws_options_for( - _type: Optional[NodeType], preserve_whitespace: WSType, base: int + _type: NodeType | None, + preserve_whitespace: WSType, + base: int, ) -> int: if preserve_whitespace is not None: return (OPT_PRESERVE_WS if preserve_whitespace else 0) | ( @@ -329,29 +336,29 @@ def ws_options_for( class NodeContext: - match: Optional[ContentMatch] - content: List[Node] + match: ContentMatch | None + content: list[Node] - active_marks: List[Mark] - stash_marks: List[Mark] + active_marks: list[Mark] + stash_marks: list[Mark] - type: Optional[NodeType] + type: NodeType | None options: int - attrs: Optional[Attrs] - marks: List[Mark] - pending_marks: List[Mark] + attrs: Attrs | None + marks: list[Mark] + pending_marks: list[Mark] solid: bool def __init__( self, - _type: Optional[NodeType], - attrs: Optional[Attrs], - marks: List[Mark], - pending_marks: List[Mark], + _type: NodeType | None, + attrs: Attrs | None, + marks: list[Mark], + pending_marks: list[Mark], solid: bool, - match: Optional[ContentMatch], + match: ContentMatch | None, options: int, ) -> None: self.type = _type @@ -375,7 +382,7 @@ def __init__( self.active_marks = Mark.none self.stash_marks = [] - def find_wrapping(self, node: Node) -> Optional[List[NodeType]]: + def find_wrapping(self, node: Node) -> list[NodeType] | None: if not self.match: if not self.type: return [] @@ -399,10 +406,10 @@ def find_wrapping(self, node: Node) -> Optional[List[NodeType]]: return self.match.find_wrapping(node.type) - def finish(self, open_end: bool) -> Union[Node, Fragment]: + def finish(self, open_end: bool) -> Node | Fragment: if not self.options & OPT_PRESERVE_WS: try: - last: Optional[Node] = self.content[-1] + last: Node | None = self.content[-1] except IndexError: last = None @@ -419,15 +426,15 @@ def finish(self, open_end: bool) -> Union[Node, Fragment]: content = Fragment.from_(self.content) if not open_end and self.match is not None: content = content.append( - cast(Fragment, self.match.fill_before(Fragment.empty, True)) + cast(Fragment, self.match.fill_before(Fragment.empty, True)), ) return ( self.type.create(self.attrs, content, self.marks) if self.type else content ) - def pop_from_stash_mark(self, mark: Mark) -> Optional[Mark]: - found_mark: Optional[Mark] = None + def pop_from_stash_mark(self, mark: Mark) -> Mark | None: + found_mark: Mark | None = None for stash_mark in self.stash_marks[::-1]: if mark.eq(stash_mark): found_mark = stash_mark @@ -458,9 +465,9 @@ def inline_context(self, node: DOMNode) -> bool: class ParseContext: open: int = 0 - find: Optional[List[DOMPosition]] + find: list[DOMPosition] | None needs_block: bool - nodes: List[NodeContext] + nodes: list[NodeContext] options: ParseOptions is_open: bool parser: DOMParser @@ -487,7 +494,13 @@ def __init__(self, parser: DOMParser, options: ParseOptions, is_open: bool) -> N ) elif is_open: top_context = NodeContext( - None, None, Mark.none, Mark.none, True, None, top_options + None, + None, + Mark.none, + Mark.none, + True, + None, + top_options, ) else: top_context = NodeContext( @@ -566,7 +579,8 @@ def add_text_node(self, dom_: DOMNode) -> None: or ( node_before.is_text and re.search( - r"[ \t\r\n\u000c]$", cast(TextNode, node_before).text + r"[ \t\r\n\u000c]$", + cast(TextNode, node_before).text, ) is not None ) @@ -585,9 +599,7 @@ def add_text_node(self, dom_: DOMNode) -> None: else: self.find_inside(dom_) - def add_element( - self, dom_: DOMNode, match_after: Optional[ParseRule] = None - ) -> None: + def add_element(self, dom_: DOMNode, match_after: ParseRule | None = None) -> None: name = dom_.tag.lower() if name in LIST_TAGS and self.parser.normalize_lists: @@ -635,7 +647,9 @@ def add_element( else: self.add_element_by_rule( - dom_, rule, rule_id if rule.consuming is False else None + dom_, + rule, + rule_id if rule.consuming is False else None, ) def leaf_fallback(self, dom_: DOMNode) -> None: @@ -650,12 +664,12 @@ def ignore_fallback(self, dom_: DOMNode) -> None: ): self.find_place(self.parser.schema.text("-")) - def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark]]]: - add: List[Mark] = Mark.none - remove: List[Mark] = Mark.none + def read_styles(self, styles: list[str]) -> tuple[list[Mark], list[Mark]] | None: + add: list[Mark] = Mark.none + remove: list[Mark] = Mark.none for i in range(0, len(styles), 2): - after: Optional[ParseRule] = None + after: ParseRule | None = None while True: rule = self.parser.match_style(styles[i], styles[i + 1], self, after) if not rule: @@ -681,11 +695,14 @@ def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark return add, remove def add_element_by_rule( - self, dom_: DOMNode, rule: ParseRule, continue_after: Optional[ParseRule] = None + self, + dom_: DOMNode, + rule: ParseRule, + continue_after: ParseRule | None = None, ) -> None: sync: bool = False - mark: Optional[Mark] = None - node_type: Optional[NodeType] = None + mark: Mark | None = None + node_type: NodeType | None = None if rule.node is not None: node_type = self.parser.schema.nodes[rule.node] @@ -707,7 +724,7 @@ def add_element_by_rule( elif rule.get_content is not None: self.find_inside(dom_) rule.get_content(dom_, self.parser.schema).for_each( - lambda node, offset, index: self.insert_node(node) + lambda node, offset, index: self.insert_node(node), ) else: content_dom = dom_ @@ -731,8 +748,8 @@ def add_element_by_rule( def add_all( self, parent: DOMNode, - start_index: Optional[int] = None, - end_index: Optional[int] = None, + start_index: int | None = None, + end_index: int | None = None, ) -> None: index = start_index if start_index is not None else 0 @@ -753,8 +770,8 @@ def add_all( self.find_at_point(parent, index) def find_place(self, node: Node) -> bool: - route: Optional[List[NodeType]] = None - sync: Optional[NodeContext] = None + route: list[NodeType] | None = None + sync: NodeContext | None = None depth = self.open while depth >= 0: @@ -810,7 +827,10 @@ def insert_node(self, node: Node) -> bool: return False def enter( - self, type_: NodeType, attrs: Optional[Attrs] = None, preserve_ws: WSType = None + self, + type_: NodeType, + attrs: Attrs | None = None, + preserve_ws: WSType = None, ) -> bool: ok = self.find_place(type_.create(attrs)) if ok: @@ -821,7 +841,7 @@ def enter( def enter_inner( self, type_: NodeType, - attrs: Optional[Attrs] = None, + attrs: Attrs | None = None, solid: bool = False, preserve_ws: WSType = None, ) -> None: @@ -840,8 +860,14 @@ def enter_inner( self.nodes.append( NodeContext( - type_, attrs, top.active_marks, top.pending_marks, solid, None, options - ) + type_, + attrs, + top.active_marks, + top.pending_marks, + solid, + None, + options, + ), ) self.open += 1 @@ -852,13 +878,13 @@ def close_extra(self, open_end: bool = False) -> None: if i > self.open: while i > self.open: self.nodes[i - 1].content.append( - cast(Node, self.nodes[i].finish(open_end)) + cast(Node, self.nodes[i].finish(open_end)), ) i -= 1 self.nodes = self.nodes[: self.open + 1] - def finish(self) -> Union[Node, Fragment]: + def finish(self) -> Node | Fragment: self.open = 0 self.close_extra(self.is_open) return self.nodes[0].finish(self.is_open or bool(self.options.top_open)) @@ -953,7 +979,7 @@ def match(i: int, depth: int) -> bool: return False else: if depth > 0 or (depth == 0 and use_root): - next: Optional[NodeType] = self.nodes[depth].type + next: NodeType | None = self.nodes[depth].type elif option is not None and depth >= min_depth: next = option.node(depth - min_depth).type else: @@ -974,7 +1000,7 @@ def match(i: int, depth: int) -> bool: return match(len(parts) - 1, self.open) - def textblock_from_context(self) -> Optional[NodeType]: + def textblock_from_context(self) -> NodeType | None: context = self.options.context if context: @@ -988,15 +1014,15 @@ def textblock_from_context(self) -> Optional[NodeType]: if ( default is not None - and default.is_text_block + and default.is_textblock and default.default_attrs ): return default d -= 1 - for name, type_ in self.parser.schema.nodes.items(): - if type_.is_text_block and type_.default_attrs: + for type_ in self.parser.schema.nodes.values(): + if type_.is_textblock and type_.default_attrs: return type_ return None @@ -1036,7 +1062,7 @@ def remove_pending_mark(self, mark: Mark, upto: NodeContext) -> None: def normalize_list(dom_: DOMNode) -> None: child = next(iter(dom_)) - prev_item: Optional[DOMNode] = None + prev_item: DOMNode | None = None while child is not None: name = child.tag.lower() if get_node_type(child) == 1 else None @@ -1058,9 +1084,9 @@ def matches(dom_: DOMNode, selector_str: str) -> bool: return bool(dom_ in selector(dom_)) # type: ignore[operator] -def parse_styles(style: str) -> List[str]: +def parse_styles(style: str) -> list[str]: regex = r"\s*([\w-]+)\s*:\s*([^;]+)" - result: List[str] = [] + result: list[str] = [] for m in re.findall(regex, style): result.append(m[0]) @@ -1072,14 +1098,14 @@ def parse_styles(style: str) -> List[str]: def mark_may_apply(mark_type: MarkType, node_type: NodeType) -> bool: nodes = node_type.schema.nodes - for name, parent in nodes.items(): + for parent in nodes.values(): if not parent.allows_mark_type(mark_type): continue - seen: List[ContentMatch] = [] + seen: list[ContentMatch] = [] def scan(match: ContentMatch) -> bool: - seen.append(match) + seen.append(match) # noqa: B023 i = 0 while i < match.edge_count: result = match.edge(i) @@ -1088,7 +1114,7 @@ def scan(match: ContentMatch) -> bool: if _type == node_type: return True - if _next not in seen and scan(_next): + if _next not in seen and scan(_next): # noqa: B023 return True i += 1 @@ -1100,7 +1126,7 @@ def scan(match: ContentMatch) -> bool: return False -def find_same_mark_in_set(mark: Mark, mark_set: List[Mark]) -> Optional[Mark]: +def find_same_mark_in_set(mark: Mark, mark_set: list[Mark]) -> Mark | None: for comp in mark_set: if mark.eq(comp): return comp @@ -1109,18 +1135,16 @@ def find_same_mark_in_set(mark: Mark, mark_set: List[Mark]) -> Optional[Mark]: def node_contains(node: DOMNode, find: DOMNode) -> bool: - for child_node in node.iterdescendants(): - if child_node == find: - return True - - return False + return any(child_node == find for child_node in node.iterdescendants()) def compare_document_position(node1: DOMNode, node2: DOMNode) -> int: if not isinstance(node1, lxml.etree._Element) or not isinstance( - node2, lxml.etree._Element + node2, + lxml.etree._Element, ): - raise ValueError("Both arguments must be lxml Element objects.") + msg = "Both arguments must be lxml Element objects." + raise ValueError(msg) tree = lxml.etree.ElementTree(node1) @@ -1151,7 +1175,8 @@ def compare_document_position(node1: DOMNode, node2: DOMNode) -> int: def get_node_type(element: DOMNode) -> int: if not isinstance(element, lxml.etree._Element): - raise ValueError("The provided element is not an lxml HtmlElement.") + msg = "The provided element is not an lxml HtmlElement." + raise ValueError(msg) if isinstance(element, lxml.etree._Comment): return 8 # Comment node type diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index fabc247..41e878d 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, Final, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Final, Union, cast from prosemirror.utils import Attrs, JSONDict @@ -8,14 +8,14 @@ class Mark: - none: Final[List["Mark"]] = [] + none: Final[list["Mark"]] = [] def __init__(self, type: "MarkType", attrs: Attrs) -> None: self.type = type self.attrs = attrs - def add_to_set(self, set: List["Mark"]) -> List["Mark"]: - copy: Optional[List["Mark"]] = None + def add_to_set(self, set: list["Mark"]) -> list["Mark"]: + copy: list["Mark"] | None = None placed = False for i in range(len(set)): other = set[i] @@ -40,10 +40,10 @@ def add_to_set(self, set: List["Mark"]) -> List["Mark"]: copy.append(self) return copy - def remove_from_set(self, set: List["Mark"]) -> List["Mark"]: + def remove_from_set(self, set: list["Mark"]) -> list["Mark"]: return [item for item in set if not item.eq(self)] - def is_in_set(self, set: List["Mark"]) -> bool: + def is_in_set(self, set: list["Mark"]) -> bool: return any(item.eq(self) for item in set) def eq(self, other: "Mark") -> bool: @@ -61,23 +61,25 @@ def from_json( json_data: JSONDict, ) -> "Mark": if not json_data: - raise ValueError("Invalid input for Mark.fromJSON") + msg = "Invalid input for Mark.fromJSON" + raise ValueError(msg) name = json_data["type"] type = schema.marks.get(name) if not type: - raise ValueError(f"There is no mark type {name} in this schema") - return type.create(cast(Optional[JSONDict], json_data.get("attrs"))) + msg = f"There is no mark type {name} in this schema" + raise ValueError(msg) + return type.create(cast(JSONDict | None, json_data.get("attrs"))) @classmethod - def same_set(cls, a: List["Mark"], b: List["Mark"]) -> bool: + def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: if a == b: return True if len(a) != len(b): return False - return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b)) + return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b, strict=True)) @classmethod - def set_from(cls, marks: Union[List["Mark"], "Mark", None]) -> List["Mark"]: + def set_from(cls, marks: Union[list["Mark"], "Mark", None]) -> list["Mark"]: if not marks: return cls.none if isinstance(marks, Mark): diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 55d17bc..1135d4e 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,7 +1,6 @@ import copy -from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypedDict, Union, cast - -from typing_extensions import TypeGuard +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional, TypedDict, TypeGuard, Union, cast from prosemirror.utils import Attrs, JSONDict, text_length @@ -30,8 +29,8 @@ def __init__( self, type: "NodeType", attrs: "Attrs", - content: Optional[Fragment], - marks: List[Mark], + content: Fragment | None, + marks: list[Mark], ) -> None: self.type = type self.attrs = attrs @@ -59,13 +58,14 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + f: Callable[["Node", int, Optional["Node"], int], bool | None], start_pos: int = 0, ) -> None: self.content.nodes_between(from_, to, f, start_pos, self) def descendants( - self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + self, + f: Callable[["Node", int, Optional["Node"], int], bool | None], ) -> None: self.nodes_between(0, self.content.size, f) @@ -80,7 +80,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: return self.content.text_between(from_, to, block_separator, leaf_text) @@ -104,7 +104,7 @@ def has_markup( self, type: "NodeType", attrs: Optional["Attrs"] = None, - marks: Optional[List[Mark]] = None, + marks: list[Mark] | None = None, ) -> bool: return ( self.type.name == type.name @@ -112,23 +112,26 @@ def has_markup( and (Mark.same_set(self.marks, marks or Mark.none)) ) - def copy(self, content: Optional[Fragment] = None) -> "Node": + def copy(self, content: Fragment | None = None) -> "Node": if content == self.content: return self return self.__class__(self.type, self.attrs, content, self.marks) - def mark(self, marks: List[Mark]) -> "Node": + def mark(self, marks: list[Mark]) -> "Node": if marks == self.marks: return self return self.__class__(self.type, self.attrs, self.content, marks) - def cut(self, from_: int, to: Optional[int] = None) -> "Node": + def cut(self, from_: int, to: int | None = None) -> "Node": if from_ == 0 and to == self.content.size: return self return self.copy(self.content.cut(from_, to)) def slice( - self, from_: int, to: Optional[int] = None, include_parents: bool = False + self, + from_: int, + to: int | None = None, + include_parents: bool = False, ) -> Slice: if to is None: to = self.content.size @@ -184,13 +187,19 @@ def resolve_no_cache(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve(self, pos) def range_has_mark( - self, from_: int, to: int, type: Union["Mark", "MarkType"] + self, + from_: int, + to: int, + type: Union["Mark", "MarkType"], ) -> bool: found = False if to > from_: def iteratee( - node: "Node", pos: int, parent: Optional["Node"], index: int + node: "Node", + pos: int, + parent: Optional["Node"], + index: int, ) -> bool: nonlocal found if type.is_in_set(node.marks): @@ -205,8 +214,8 @@ def is_block(self) -> bool: return self.type.is_block @property - def is_text_block(self) -> bool: - return self.type.is_text_block + def is_textblock(self) -> bool: + return self.type.is_textblock @property def inline_content(self) -> bool: @@ -229,11 +238,7 @@ def is_atom(self) -> bool: return self.type.is_atom def __str__(self) -> str: - to_debug_string = ( - self.type.spec["toDebugString"] - if "toDebugString" in self.type.spec - else None - ) + to_debug_string = self.type.spec.get("toDebugString", None) if to_debug_string: return to_debug_string(self) name = self.type.name @@ -247,7 +252,8 @@ def __repr__(self) -> str: def content_match_at(self, index: int) -> "ContentMatch": match = self.type.content_match.match_fragment(self.content, 0, index) if not match: - raise ValueError("Called contentMatchAt on a node with invalid content") + msg = "Called contentMatchAt on a node with invalid content" + raise ValueError(msg) return match def can_replace( @@ -256,12 +262,12 @@ def can_replace( to: int, replacement: Fragment = Fragment.empty, start: int = 0, - end: Optional[int] = None, + end: int | None = None, ) -> bool: if end is None: end = replacement.child_count one = self.content_match_at(from_).match_fragment(replacement, start, end) - two: Optional["ContentMatch"] = None + two: "ContentMatch" | None = None if one: two = one.match_fragment(self.content, to) if not two or not two.valid_end: @@ -272,12 +278,16 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: "NodeType", marks: Optional[List[Mark]] = None + self, + from_: int, + to: int, + type: "NodeType", + marks: list[Mark] | None = None, ) -> bool: if marks and not self.type.allows_marks(marks): return False start = self.content_match_at(from_).match_type(type) - end: Optional["ContentMatch"] = None + end: "ContentMatch" | None = None if start: end = start.match_fragment(self.content, to) return end.valid_end if end else False @@ -290,17 +300,17 @@ def can_append(self, other: "Node") -> bool: def check(self) -> None: if not self.type.valid_content(self.content): - raise ValueError( - f"Invalid content for node {self.type.name}: {str(self.content)[:50]}" - ) + msg = f"Invalid content for node {self.type.name}: {str(self.content)[:50]}" + raise ValueError(msg) copy = Mark.none for mark in self.marks: copy = mark.add_to_set(copy) if not Mark.same_set(copy, self.marks): - raise ValueError( + msg = ( f"Invalid collection of marks for node {self.type.name}:" f" {[m.type.name for m in self.marks]!r}" ) + raise ValueError(msg) def iteratee(node: "Node", offset: int, index: int) -> None: node.check() @@ -327,26 +337,28 @@ def to_json(self) -> JSONDict: return obj @classmethod - def from_json( - cls, schema: "Schema[Any, Any]", json_data: Union[JSONDict, str] - ) -> "Node": + def from_json(cls, schema: "Schema[Any, Any]", json_data: JSONDict | str) -> "Node": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not json_data: - raise ValueError("Invalid input for Node.from_json") + msg = "Invalid input for Node.from_json" + raise ValueError(msg) marks = None if json_data.get("marks"): if not isinstance(json_data["marks"], list): - raise ValueError("Invalid mark data for Node.fromJSON") + msg = "Invalid mark data for Node.fromJSON" + raise ValueError(msg) marks = [schema.mark_from_json(item) for item in json_data["marks"]] if json_data["type"] == "text": return schema.text(str(json_data["text"]), marks) content = Fragment.from_json(schema, json_data.get("content")) return schema.node_type(str(json_data["type"])).create( - cast("Attrs", json_data.get("attrs")), content, marks + cast("Attrs", json_data.get("attrs")), + content, + marks, ) @@ -356,21 +368,18 @@ def __init__( type: "NodeType", attrs: "Attrs", content: str, - marks: List[Mark], + marks: list[Mark], ) -> None: super().__init__(type, attrs, None, marks) if not content: - raise ValueError("Empty text nodes are not allowed") + msg = "Empty text nodes are not allowed" + raise ValueError(msg) self.text = content def __str__(self) -> str: import json - to_debug_string = ( - self.type.spec["toDebugString"] - if "toDebugString" in self.type.spec - else None - ) + to_debug_string = self.type.spec.get("toDebugString", None) if to_debug_string: return to_debug_string(self) return wrap_marks(self.marks, json.dumps(self.text)) @@ -384,7 +393,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: return self.text[from_:to] @@ -392,7 +401,7 @@ def text_between( def node_size(self) -> int: return text_length(self.text) - def mark(self, marks: List[Mark]) -> "TextNode": + def mark(self, marks: list[Mark]) -> "TextNode": return ( self if marks == self.marks @@ -404,13 +413,13 @@ def with_text(self, text: str) -> "TextNode": return self return TextNode(self.type, self.attrs, text, self.marks) - def cut(self, from_: int = 0, to: Optional[int] = None) -> "TextNode": + def cut(self, from_: int = 0, to: int | None = None) -> "TextNode": if to is None: to = text_length(self.text) if from_ == 0 and to == text_length(self.text): return self substring = self.text.encode("utf-16-le")[2 * from_ : 2 * to].decode( - "utf-16-le" + "utf-16-le", ) return self.with_text(substring) @@ -423,7 +432,7 @@ def to_json( return {**super().to_json(), "text": self.text} -def wrap_marks(marks: List[Mark], str: str) -> str: +def wrap_marks(marks: list[Mark], str: str) -> str: i = len(marks) - 1 while i >= 0: str = marks[i].type.name + "(" + str + ")" diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index a6cc4dd..f1922a3 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast from prosemirror.utils import JSONDict @@ -22,11 +22,13 @@ def remove_range(content: Fragment, from_: int, to: int) -> Fragment: index_to, offset_to = to_index_info["index"], to_index_info["offset"] if offset == from_ or cast("Node", child).is_text: if offset_to != to and not content.child(index_to).is_text: - raise ValueError("removing non-flat range") + msg = "removing non-flat range" + raise ValueError(msg) return content.cut(0, from_).append(content.cut(to)) assert child if index != index_to: - raise ValueError("removing non-flat range") + msg = "removing non-flat range" + raise ValueError(msg) return content.replace_child( index, child.copy(remove_range(child.content, from_ - offset - 1, to - offset - 1)), @@ -34,8 +36,11 @@ def remove_range(content: Fragment, from_: int, to: int) -> Fragment: def insert_into( - content: Fragment, dist: int, insert: Fragment, parent: Optional["Node"] -) -> Optional[Fragment]: + content: Fragment, + dist: int, + insert: Fragment, + parent: Optional["Node"], +) -> Fragment | None: a = content.find_index(dist) index, offset = a["index"], a["offset"] child = content.maybe_child(index) @@ -85,7 +90,7 @@ def eq(self, other: "Slice") -> bool: def __str__(self) -> str: return f"{self.content}({self.open_start},{self.open_end})" - def to_json(self) -> Optional[JSONDict]: + def to_json(self) -> JSONDict | None: if not self.content.size: return None json: JSONDict = {"content": self.content.to_json()} @@ -105,14 +110,15 @@ def to_json(self) -> Optional[JSONDict]: def from_json( cls, schema: "Schema[Any, Any]", - json_data: Optional[JSONDict], + json_data: JSONDict | None, ) -> "Slice": if not json_data: return cls.empty open_start = json_data.get("openStart", 0) or 0 open_end = json_data.get("openEnd", 0) or 0 if not isinstance(open_start, int) or not isinstance(open_end, int): - raise ValueError("invalid input for Slice.from_json") + msg = "invalid input for Slice.from_json" + raise ValueError(msg) return cls( Fragment.from_json(schema, json_data.get("content")), open_start, @@ -139,14 +145,19 @@ def max_open(cls, fragment: Fragment, open_isolating: bool = True) -> "Slice": def replace(from_: "ResolvedPos", to: "ResolvedPos", slice: Slice) -> "Node": if slice.open_start > from_.depth: - raise ReplaceError("Inserted content deeper than insertion position") + msg = "Inserted content deeper than insertion position" + raise ReplaceError(msg) if from_.depth - slice.open_start != to.depth - slice.open_end: - raise ReplaceError("Inconsistent open depths") + msg = "Inconsistent open depths" + raise ReplaceError(msg) return replace_outer(from_, to, slice, 0) def replace_outer( - from_: "ResolvedPos", to: "ResolvedPos", slice: Slice, depth: int + from_: "ResolvedPos", + to: "ResolvedPos", + slice: Slice, + depth: int, ) -> "Node": index = from_.index(depth) node = from_.node(depth) @@ -177,7 +188,8 @@ def replace_outer( def check_join(main: "Node", sub: "Node") -> None: if not sub.type.compatible_content(main.type): - raise ReplaceError(f"Cannot join {sub.type.name} onto {main.type.name}") + msg = f"Cannot join {sub.type.name} onto {main.type.name}" + raise ReplaceError(msg) def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": @@ -186,7 +198,7 @@ def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": return node -def add_node(child: "Node", target: List["Node"]) -> None: +def add_node(child: "Node", target: list["Node"]) -> None: last = len(target) - 1 if last >= 0 and pm_node.is_text(child) and child.same_markup(target[last]): target[last] = child.with_text(cast("TextNode", target[last]).text + child.text) @@ -198,7 +210,7 @@ def add_range( start: Optional["ResolvedPos"], end: Optional["ResolvedPos"], depth: int, - target: List["Node"], + target: list["Node"], ) -> None: node = cast("ResolvedPos", end or start).node(depth) start_index = 0 @@ -220,7 +232,8 @@ def add_range( def close(node: "Node", content: Fragment) -> "Node": if not node.type.valid_content(content): - raise ReplaceError(f"Invalid content for node {node.type.name}") + msg = f"Invalid content for node {node.type.name}" + raise ReplaceError(msg) return node.copy(content) @@ -233,7 +246,7 @@ def replace_three_way( ) -> Fragment: open_start = joinable(from_, start, depth + 1) if from_.depth > depth else None open_end = joinable(end, to, depth + 1) if to.depth > depth else None - content: List["Node"] = [] + content: list["Node"] = [] add_range(None, from_, depth, content) if open_start and open_end and start.index(depth) == end.index(depth): check_join(open_start, open_end) @@ -244,7 +257,8 @@ def replace_three_way( else: if open_start: add_node( - close(open_start, replace_two_way(from_, start, depth + 1)), content + close(open_start, replace_two_way(from_, start, depth + 1)), + content, ) add_range(start, end, depth, content) if open_end: @@ -254,7 +268,7 @@ def replace_three_way( def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Fragment: - content: List["Node"] = [] + content: list["Node"] = [] add_range(None, from_, depth, content) if from_.depth > depth: type = joinable(from_, to, depth + 1) @@ -264,8 +278,9 @@ def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Frag def prepare_slice_for_replace( - slice: Slice, along: "ResolvedPos" -) -> Dict[str, "ResolvedPos"]: + slice: Slice, + along: "ResolvedPos", +) -> dict[str, "ResolvedPos"]: extra = along.depth - slice.open_start parent = along.node(extra) node = parent.copy(slice.content) diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index dc5e7ec..2aeba5a 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast +from collections.abc import Callable +from typing import TYPE_CHECKING, Optional, Union, cast from .mark import Mark @@ -8,14 +9,17 @@ class ResolvedPos: def __init__( - self, pos: int, path: List[Union["Node", int]], parent_offset: int + self, + pos: int, + path: list[Union["Node", int]], + parent_offset: int, ) -> None: self.pos = pos self.path = path self.depth = int(len(path) / 3 - 1) self.parent_offset = parent_offset - def resolve_depth(self, val: Optional[int] = None) -> int: + def resolve_depth(self, val: int | None = None) -> int: if val is None: return self.depth return self.depth + val if val < 0 else val @@ -31,7 +35,7 @@ def doc(self) -> "Node": def node(self, depth: int) -> "Node": return cast("Node", self.path[self.resolve_depth(depth) * 3]) - def index(self, depth: Optional[int] = None) -> int: + def index(self, depth: int | None = None) -> int: return cast(int, self.path[self.resolve_depth(depth) * 3 + 1]) def index_after(self, depth: int) -> int: @@ -40,26 +44,28 @@ def index_after(self, depth: int) -> int: 0 if depth == self.depth and not self.text_offset else 1 ) - def start(self, depth: Optional[int] = None) -> int: + def start(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) return 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 - def end(self, depth: Optional[int] = None) -> int: + def end(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) return self.start(depth) + self.node(depth).content.size - def before(self, depth: Optional[int] = None) -> int: + def before(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) if not depth: - raise ValueError("There is no position before the top level node") + msg = "There is no position before the top level node" + raise ValueError(msg) return ( self.pos if depth == self.depth + 1 else cast(int, self.path[depth * 3 - 1]) ) - def after(self, depth: Optional[int] = None) -> int: + def after(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) if not depth: - raise ValueError("There is no position after the top level node") + msg = "There is no position after the top level node" + raise ValueError(msg) return ( self.pos if depth == self.depth + 1 @@ -89,7 +95,7 @@ def node_before(self) -> Optional["Node"]: return self.parent.child(index).cut(0, d_off) return None if index == 0 else self.parent.child(index - 1) - def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: + def pos_at_index(self, index: int, depth: int | None = None) -> int: depth = self.resolve_depth(depth) node = cast("Node", self.path[depth * 3]) pos = 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 @@ -97,7 +103,7 @@ def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: pos += node.child(i).node_size return pos - def marks(self) -> List["Mark"]: + def marks(self) -> list["Mark"]: parent = self.parent index = self.index() if parent.content.size == 0: @@ -119,7 +125,7 @@ def marks(self) -> List["Mark"]: i += 1 return marks - def marks_across(self, end: "ResolvedPos") -> Optional[List["Mark"]]: + def marks_across(self, end: "ResolvedPos") -> list["Mark"] | None: after = self.parent.maybe_child(self.index()) if not after or not after.is_inline: return None @@ -146,7 +152,7 @@ def shared_depth(self, pos: int) -> int: def block_range( self, other: Optional["ResolvedPos"] = None, - pred: Optional[Callable[["Node"], bool]] = None, + pred: Callable[["Node"], bool] | None = None, ) -> Optional["NodeRange"]: if other is None: other = self @@ -180,8 +186,9 @@ def __str__(self) -> str: @classmethod def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": if not (pos >= 0 and pos <= doc.content.size): - raise ValueError(f"Position {pos} out of range") - path: List[Union["Node", int]] = [] + msg = f"Position {pos} out of range" + raise ValueError(msg) + path: list["Node" | int] = [] start = 0 parent_offset = pos node = doc diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index f21914f..1ab5ab2 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -1,17 +1,15 @@ +from collections.abc import Callable from typing import ( Any, - Callable, - Dict, Generic, - List, Literal, Optional, + TypeAlias, TypeVar, - Union, cast, ) -from typing_extensions import NotRequired, TypeAlias, TypedDict +from typing_extensions import NotRequired, TypedDict from prosemirror.model.content import ContentMatch from prosemirror.model.fragment import Fragment @@ -20,7 +18,7 @@ from prosemirror.utils import JSON, Attrs, JSONDict -def default_attrs(attrs: "Attributes") -> Optional[Attrs]: +def default_attrs(attrs: "Attributes") -> Attrs | None: defaults = {} for attr_name, attr in attrs.items(): if not attr.has_default: @@ -29,7 +27,7 @@ def default_attrs(attrs: "Attributes") -> Optional[Attrs]: return defaults -def compute_attrs(attrs: "Attributes", value: Optional[Attrs]) -> Attrs: +def compute_attrs(attrs: "Attributes", value: Attrs | None) -> Attrs: built = {} for name in attrs: given = None @@ -69,7 +67,7 @@ class NodeType: inline_content: bool - mark_set: Optional[List["MarkType"]] + mark_set: list["MarkType"] | None def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> None: self.name = name @@ -78,7 +76,7 @@ def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> N self.groups = spec["group"].split(" ") if "group" in spec else [] self.attrs = init_attrs(spec.get("attrs")) self.default_attrs = default_attrs(self.attrs) - self._content_match: Optional[ContentMatch] = None + self._content_match: ContentMatch | None = None self.mark_set = None self.inline_content = False self.is_block = not (spec.get("inline") or name == "text") @@ -98,7 +96,7 @@ def is_inline(self) -> bool: return not self.is_block @property - def is_text_block(self) -> bool: # FIXME: name is wrong, should be is_textblock + def is_textblock(self) -> bool: return self.is_block and self.inline_content @property @@ -116,27 +114,25 @@ def whitespace(self) -> Literal["pre", "normal"]: ) def has_required_attrs(self) -> bool: - for n in self.attrs: - if self.attrs[n].is_required: - return True - return False + return any(self.attrs[n].is_required for n in self.attrs) def compatible_content(self, other: "NodeType") -> bool: return self == other or (self.content_match.compatible(other.content_match)) - def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: + def compute_attrs(self, attrs: Attrs | None) -> Attrs: if attrs is None and self.default_attrs is not None: return self.default_attrs return compute_attrs(self.attrs, attrs) def create( self, - attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, List[Node], None] = None, - marks: Optional[List[Mark]] = None, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: if self.is_text: - raise ValueError("NodeType.create cannot construct text nodes") + msg = "NodeType.create cannot construct text nodes" + raise ValueError(msg) return Node( self, self.compute_attrs(attrs), @@ -146,9 +142,9 @@ def create( def create_checked( self, - attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, List[Node], None] = None, - marks: Optional[List[Mark]] = None, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: content = Fragment.from_(content) if not self.valid_content(content): @@ -157,10 +153,10 @@ def create_checked( def create_and_fill( self, - attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, List[Node], None] = None, - marks: Optional[List[Mark]] = None, - ) -> Optional[Node]: + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, + ) -> Node | None: attrs = self.compute_attrs(attrs) frag = Fragment.from_(content) if frag.size: @@ -188,15 +184,15 @@ def valid_content(self, content: Fragment) -> bool: def allows_mark_type(self, mark_type: "MarkType") -> bool: return self.mark_set is None or mark_type in self.mark_set - def allows_marks(self, marks: List[Mark]) -> bool: + def allows_marks(self, marks: list[Mark]) -> bool: if self.mark_set is None: return True return all(self.allows_mark_type(mark.type) for mark in marks) - def allowed_marks(self, marks: List[Mark]) -> List[Mark]: + def allowed_marks(self, marks: list[Mark]) -> list[Mark]: if self.mark_set is None: return marks - copy: Optional[List[Mark]] = None + copy: list[Mark] | None = None for i, mark in enumerate(marks): if not self.allows_mark_type(mark.type): if not copy: @@ -212,20 +208,25 @@ def allowed_marks(self, marks: List[Mark]) -> List[Mark]: @classmethod def compile( - cls, nodes: Dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" - ) -> Dict["Nodes", "NodeType"]: - result: Dict["Nodes", "NodeType"] = {} + cls, + nodes: dict["Nodes", "NodeSpec"], + schema: "Schema[Nodes, Marks]", + ) -> dict["Nodes", "NodeType"]: + result: dict["Nodes", "NodeType"] = {} for name, spec in nodes.items(): result[name] = NodeType(name, schema, spec) top_node = cast(Nodes, schema.spec.get("topNode") or "doc") if not result.get(top_node): - raise ValueError(f"Schema is missing its top node type {top_node}") + msg = f"Schema is missing its top node type {top_node}" + raise ValueError(msg) if not result.get(cast(Nodes, "text")): - raise ValueError("every schema needs a 'text' type") + msg = "every schema needs a 'text' type" + raise ValueError(msg) if result[cast(Nodes, "text")].attrs: - raise ValueError("the text node type should not have attributes") + msg = "the text node type should not have attributes" + raise ValueError(msg) return result def __str__(self) -> str: @@ -235,7 +236,7 @@ def __repr__(self) -> str: return self.__str__() -Attributes: TypeAlias = Dict[str, "Attribute"] +Attributes: TypeAlias = dict[str, "Attribute"] class Attribute: @@ -249,11 +250,15 @@ def is_required(self) -> bool: class MarkType: - excluded: List["MarkType"] - instance: Optional[Mark] + excluded: list["MarkType"] + instance: Mark | None def __init__( - self, name: str, rank: int, schema: "Schema[Any, Any]", spec: "MarkSpec" + self, + name: str, + rank: int, + schema: "Schema[Any, Any]", + spec: "MarkSpec", ) -> None: self.name = name self.schema = schema @@ -268,7 +273,7 @@ def __init__( def create( self, - attrs: Optional[Attrs] = None, + attrs: Attrs | None = None, ) -> Mark: if not attrs and self.instance: return self.instance @@ -276,19 +281,19 @@ def create( @classmethod def compile( - cls, marks: Dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" - ) -> Dict["Marks", "MarkType"]: + cls, + marks: dict["Marks", "MarkSpec"], + schema: "Schema[Nodes, Marks]", + ) -> dict["Marks", "MarkType"]: result = {} - rank = 0 - for name, spec in marks.items(): + for rank, (name, spec) in enumerate(marks.items()): result[name] = MarkType(name, rank, schema, spec) - rank += 1 return result - def remove_from_set(self, set_: List["Mark"]) -> List["Mark"]: + def remove_from_set(self, set_: list["Mark"]) -> list["Mark"]: return [item for item in set_ if item.type != self] - def is_in_set(self, set: List[Mark]) -> Optional[Mark]: + def is_in_set(self, set: list[Mark]) -> Mark | None: return next((item for item in set if item.type == self), None) def excludes(self, other: "MarkType") -> bool: @@ -311,13 +316,13 @@ class SchemaSpec(TypedDict, Generic[Nodes, Marks]): # determines which [parse rules](#model.NodeSpec.parseDOM) take # precedence by default, and which nodes come first in a given # [group](#model.NodeSpec.group). - nodes: Dict[Nodes, "NodeSpec"] + nodes: dict[Nodes, "NodeSpec"] # The mark types that exist in this schema. The order in which they # are provided determines the order in which [mark # sets](#model.Mark.addToSet) are sorted and in which [parse # rules](#model.MarkSpec.parseDOM) are tried. - marks: NotRequired[Dict[Marks, "MarkSpec"]] + marks: NotRequired[dict[Marks, "MarkSpec"]] # The name of the default top-level node for the schema. Defaults # to `"doc"`. @@ -344,12 +349,12 @@ class NodeSpec(TypedDict, total=False): defining: bool isolating: bool toDOM: Callable[[Node], Any] # FIXME: add types - parseDOM: List[Dict[str, Any]] # FIXME: add types + parseDOM: list[dict[str, Any]] # FIXME: add types toDebugString: Callable[[Node], str] leafText: Callable[[Node], str] -AttributeSpecs: TypeAlias = Dict[str, "AttributeSpec"] +AttributeSpecs: TypeAlias = dict[str, "AttributeSpec"] class MarkSpec(TypedDict, total=False): @@ -359,7 +364,7 @@ class MarkSpec(TypedDict, total=False): group: str spanning: bool toDOM: Callable[[Mark, bool], Any] # FIXME: add types - parseDOM: List[Dict[str, Any]] # FIXME: add types + parseDOM: list[dict[str, Any]] # FIXME: add types class AttributeSpec(TypedDict, total=False): @@ -369,9 +374,9 @@ class AttributeSpec(TypedDict, total=False): class Schema(Generic[Nodes, Marks]): spec: SchemaSpec[Nodes, Marks] - nodes: Dict[Nodes, "NodeType"] + nodes: dict[Nodes, "NodeType"] - marks: Dict[Marks, "MarkType"] + marks: dict[Marks, "MarkType"] def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: self.spec = spec @@ -380,13 +385,15 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: content_expr_cache = {} for prop in self.nodes: if prop in self.marks: - raise ValueError(f"{prop} can not be both a node and a mark") + msg = f"{prop} can not be both a node and a mark" + raise ValueError(msg) type = self.nodes[prop] content_expr = type.spec.get("content", "") mark_expr = type.spec.get("marks") if content_expr not in content_expr_cache: content_expr_cache[content_expr] = ContentMatch.parse( - content_expr, cast(Dict[str, "NodeType"], self.nodes) + content_expr, + cast(dict[str, "NodeType"], self.nodes), ) type.content_match = content_expr_cache[content_expr] @@ -408,40 +415,45 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: ) self.top_node_type = self.nodes[cast(Nodes, self.spec.get("topNode") or "doc")] - self.cached: Dict[str, Any] = {} + self.cached: dict[str, Any] = {} self.cached["wrappings"] = {} def node( self, - type: Union[str, NodeType], - attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, List[Node], None] = None, - marks: Optional[List[Mark]] = None, + type: str | NodeType, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: if isinstance(type, str): type = self.node_type(type) elif not isinstance(type, NodeType): - raise ValueError(f"Invalid node type: {type}") + msg = f"Invalid node type: {type}" + raise ValueError(msg) elif type.schema != self: - raise ValueError(f"Node type from different schema used ({type.name})") + msg = f"Node type from different schema used ({type.name})" + raise ValueError(msg) return type.create_checked(attrs, content, marks) - def text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: + def text(self, text: str, marks: list[Mark] | None = None) -> TextNode: type = self.nodes[cast(Nodes, "text")] return TextNode( - type, cast(Attrs, type.default_attrs), text, Mark.set_from(marks) + type, + cast(Attrs, type.default_attrs), + text, + Mark.set_from(marks), ) def mark( self, - type: Union[str, MarkType], - attrs: Optional[Attrs] = None, + type: str | MarkType, + attrs: Attrs | None = None, ) -> Mark: if isinstance(type, str): type = self.marks[cast(Marks, type)] return type.create(attrs) - def node_from_json(self, json_data: JSONDict) -> Union[Node, TextNode]: + def node_from_json(self, json_data: JSONDict) -> Node | TextNode: return Node.from_json(self, json_data) def mark_from_json( @@ -453,11 +465,12 @@ def mark_from_json( def node_type(self, name: str) -> NodeType: found = self.nodes.get(cast(Nodes, name)) if not found: - raise ValueError(f"Unknown node type: {name}") + msg = f"Unknown node type: {name}" + raise ValueError(msg) return found -def gather_marks(schema: Schema[Any, Any], marks: List[str]) -> List[MarkType]: +def gather_marks(schema: Schema[Any, Any], marks: list[str]) -> list[MarkType]: found = [] for name in marks: mark = schema.marks.get(name) @@ -472,5 +485,6 @@ def gather_marks(schema: Schema[Any, Any], marks: List[str]) -> List[MarkType]: ok = mark found.append(mark) if not ok: - raise SyntaxError(f"unknow mark type: '{mark}'") + msg = f"unknow mark type: '{mark}'" + raise SyntaxError(msg) return found diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 2c4ff43..fa81861 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -1,13 +1,7 @@ import html +from collections.abc import Callable, Mapping, Sequence from typing import ( Any, - Callable, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, Union, cast, ) @@ -21,7 +15,7 @@ class DocumentFragment: - def __init__(self, children: List[HTMLNode]) -> None: + def __init__(self, children: list[HTMLNode]) -> None: self.children = children def __str__(self) -> str: @@ -49,7 +43,10 @@ def __str__(self) -> str: class Element(DocumentFragment): def __init__( - self, name: str, attrs: Dict[str, str], children: List[HTMLNode] + self, + name: str, + attrs: dict[str, str], + children: list[HTMLNode], ) -> None: self.name = name self.attrs = attrs @@ -65,25 +62,27 @@ def __str__(self) -> str: return f"<{open_tag_str}>{children_str}{self.name}>" -HTMLOutputSpec = Union[str, Sequence[Any], Element] +HTMLOutputSpec = str | Sequence[Any] | Element class DOMSerializer: def __init__( self, - nodes: Dict[str, Callable[[Node], HTMLOutputSpec]], - marks: Dict[str, Callable[[Mark, bool], HTMLOutputSpec]], + nodes: dict[str, Callable[[Node], HTMLOutputSpec]], + marks: dict[str, Callable[[Mark, bool], HTMLOutputSpec]], ) -> None: self.nodes = nodes self.marks = marks def serialize_fragment( - self, fragment: Fragment, target: Union[Element, DocumentFragment, None] = None + self, + fragment: Fragment, + target: Element | DocumentFragment | None = None, ) -> DocumentFragment: tgt: DocumentFragment = target or DocumentFragment(children=[]) top = tgt - active: Optional[List[Tuple[Mark, DocumentFragment]]] = None + active: list[tuple[Mark, DocumentFragment]] | None = None def each(node: Node, offset: int, index: int) -> None: nonlocal top, active @@ -124,7 +123,8 @@ def serialize_node_inner(self, node: Node) -> HTMLNode: dom, content_dom = type(self).render_spec(self.nodes[node.type.name](node)) if content_dom: if node.is_leaf: - raise Exception("Content hole not allowed in a leaf node spec") + msg = "Content hole not allowed in a leaf node spec" + raise Exception(msg) self.serialize_fragment(node.content, content_dom) return dom @@ -139,25 +139,26 @@ def serialize_node(self, node: Node) -> HTMLNode: return dom def serialize_mark( - self, mark: Mark, inline: bool - ) -> Optional[Tuple[HTMLNode, Optional[Element]]]: + self, + mark: Mark, + inline: bool, + ) -> tuple[HTMLNode, Element | None] | None: to_dom = self.marks.get(mark.type.name) if to_dom: return type(self).render_spec(to_dom(mark, inline)) return None @classmethod - def render_spec( - cls, structure: HTMLOutputSpec - ) -> Tuple[HTMLNode, Optional[Element]]: + def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | None]: if isinstance(structure, str): return html.escape(structure), None if isinstance(structure, Element): return structure, None tag_name = structure[0] if " " in tag_name[1:]: - raise NotImplementedError("XML namespaces are not supported") - content_dom: Optional[Element] = None + msg = "XML namespaces are not supported" + raise NotImplementedError(msg) + content_dom: Element | None = None dom = Element(name=tag_name, attrs={}, children=[]) attrs = structure[1] if len(structure) > 1 else None start = 1 @@ -167,21 +168,22 @@ def render_spec( if value is None: continue if " " in name[1:]: - raise NotImplementedError("XML namespaces are not supported") + msg = "XML namespaces are not supported" + raise NotImplementedError(msg) dom.attrs[name] = value for i in range(start, len(structure)): child = structure[i] if child == 0: if i < len(structure) - 1 or i > start: - raise Exception( - "Content hole must be the only child of its parent node" - ) + msg = "Content hole must be the only child of its parent node" + raise Exception(msg) return dom, dom inner, inner_content = cls.render_spec(child) dom.children.append(inner) if inner_content: if content_dom: - raise Exception("Multiple content holes") + msg = "Multiple content holes" + raise Exception(msg) content_dom = inner_content return dom, content_dom @@ -191,8 +193,9 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMSerializer": @classmethod def nodes_from_schema( - cls, schema: Schema[str, Any] - ) -> Dict[str, Callable[["Node"], HTMLOutputSpec]]: + cls, + schema: Schema[str, Any], + ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: result["text"] = lambda node: node.text @@ -200,14 +203,15 @@ def nodes_from_schema( @classmethod def marks_from_schema( - cls, schema: Schema[Any, Any] - ) -> Dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: + cls, + schema: Schema[Any, Any], + ) -> dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: return gather_to_dom(schema.marks) def gather_to_dom( - obj: Mapping[str, Union[NodeType, MarkType]], -) -> Dict[str, Callable[..., Any]]: + obj: Mapping[str, NodeType | MarkType], +) -> dict[str, Callable[..., Any]]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index 2c022a3..d7a35f8 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from prosemirror.model import Schema from prosemirror.model.schema import MarkSpec, NodeSpec @@ -9,7 +9,7 @@ pre_dom = ["pre", ["code", 0]] br_dom = ["br"] -nodes: Dict[str, NodeSpec] = { +nodes: dict[str, NodeSpec] = { "doc": {"content": "block+"}, "paragraph": { "content": "inline*", @@ -66,7 +66,7 @@ "src": dom_.get("src"), "title": dom_.get("title"), }, - } + }, ], "toDOM": lambda node: [ "img", @@ -90,7 +90,7 @@ strong_dom = ["strong", 0] code_dom = ["code", 0] -marks: Dict[str, MarkSpec] = { +marks: dict[str, MarkSpec] = { "link": { "attrs": {"href": {}, "title": {"default": None}}, "inclusive": False, diff --git a/prosemirror/schema/list/schema_list.py b/prosemirror/schema/list/schema_list.py index 8d3ee87..eca4318 100644 --- a/prosemirror/schema/list/schema_list.py +++ b/prosemirror/schema/list/schema_list.py @@ -1,4 +1,4 @@ -from typing import Dict, cast +from typing import cast from prosemirror.model.schema import Nodes, NodeSpec @@ -27,15 +27,19 @@ def add(obj: "NodeSpec", props: "NodeSpec") -> "NodeSpec": def add_list_nodes( - nodes: Dict["Nodes", "NodeSpec"], item_content: str, list_group: str -) -> Dict["Nodes", "NodeSpec"]: + nodes: dict["Nodes", "NodeSpec"], + item_content: str, + list_group: str, +) -> dict["Nodes", "NodeSpec"]: copy = nodes.copy() copy.update({ cast(Nodes, "ordered_list"): add( - orderd_list, NodeSpec(content="list_item+", group=list_group) + orderd_list, + NodeSpec(content="list_item+", group=list_group), ), cast(Nodes, "bullet_list"): add( - bullet_list, NodeSpec(content="list_item+", group=list_group) + bullet_list, + NodeSpec(content="list_item+", group=list_group), ), cast(Nodes, "list_item"): add(list_item, NodeSpec(content=item_content)), }) diff --git a/prosemirror/test_builder/__init__.py b/prosemirror/test_builder/__init__.py index aea745a..ba7e245 100644 --- a/prosemirror/test_builder/__init__.py +++ b/prosemirror/test_builder/__init__.py @@ -14,7 +14,7 @@ "doc": { "content": "block+", "attrs": {"meta": {"default": None}}, - } + }, }) test_schema: Schema[Any, Any] = Schema({ diff --git a/prosemirror/test_builder/build.py b/prosemirror/test_builder/build.py index fa5ecd6..269ab81 100644 --- a/prosemirror/test_builder/build.py +++ b/prosemirror/test_builder/build.py @@ -1,19 +1,21 @@ # type: ignore +import contextlib import re -from typing import Any, Callable, Dict, List, Tuple, Union +from collections.abc import Callable +from typing import Any -from prosemirror.model import Node, Schema -from prosemirror.utils import JSONDict +from prosemirror.model import Node, NodeType, Schema +from prosemirror.utils import Attrs, JSONDict NO_TAG = Node.tag = {} def flatten( schema: Schema[Any, Any], - children: List[Union[Node, JSONDict, str]], + children: list[Node | JSONDict | str], f: Callable[[Node], Node], -) -> Tuple[List[Node], Dict[str, int]]: +) -> tuple[list[Node], dict[str, int]]: result, pos, tag = [], 0, NO_TAG for child in children: @@ -58,44 +60,38 @@ def flatten( return result, tag -def id(x): - return x - - -def block(type, attrs): +def block(type: NodeType, attrs: Attrs | None = None): def result(*args): my_attrs = attrs if ( args and args[0] - and not isinstance(args[0], (str, Node)) + and not isinstance(args[0], str | Node) and not getattr(args[0], "flat", None) and "flat" not in args[0] ): my_attrs.update(args[0]) args = args[1:] - nodes, tag = flatten(type.schema, args, id) + nodes, tag = flatten(type.schema, args, lambda x: x) node = type.create(my_attrs, nodes) if tag != NO_TAG: node.tag = tag return node if type.is_leaf: - try: + with contextlib.suppress(ValueError): result.flat = [type.create(attrs)] - except ValueError: - pass return result -def mark(type, attrs): +def mark(type: NodeType, attrs: Attrs): def result(*args): my_attrs = attrs.copy() if ( args and args[0] - and not isinstance(args[0], (str, Node)) + and not isinstance(args[0], str | Node) and not getattr(args[0], "flat", None) and "flat" not in args[0] ): @@ -114,7 +110,7 @@ def f(n): return result -def builders(schema, names): +def builders(schema: Schema[Any, Any], names): result = {"schema": schema} for name in schema.nodes: result[name] = block(schema.nodes[name], {}) diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index cb6ef01..42c1190 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, cast +from typing import Any, cast from prosemirror.model import Fragment, Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -37,7 +37,7 @@ def invert(self, doc: Node) -> Step: assert node_at_pos is not None return AttrStep(self.pos, self.attr, node_at_pos.attrs[self.attr]) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AttrStep(pos.pos, self.attr, self.value) @@ -50,18 +50,18 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json( - schema: Schema[Any, Any], json_data: Union[JSONDict, str] - ) -> "AttrStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AttrStep": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["pos"], int) or not isinstance( - json_data["attr"], str + json_data["attr"], + str, ): - raise ValueError("Invalid input for AttrStep.from_json") + msg = "Invalid input for AttrStep.from_json" + raise ValueError(msg) return AttrStep(json_data["pos"], json_data["attr"], json_data["value"]) diff --git a/prosemirror/transform/doc_attr_step.py b/prosemirror/transform/doc_attr_step.py index b7fb63b..dc9ae58 100644 --- a/prosemirror/transform/doc_attr_step.py +++ b/prosemirror/transform/doc_attr_step.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, cast +from typing import Any, cast from prosemirror.model import Node, Schema from prosemirror.transform.map import Mappable, StepMap @@ -7,7 +7,7 @@ class DocAttrStep(Step): - def __init__(self, attr: str, value: JSON): + def __init__(self, attr: str, value: JSON) -> None: super().__init__() self.attr = attr self.value = value @@ -26,7 +26,7 @@ def get_map(self) -> StepMap: def invert(self, doc: Node) -> Step: return DocAttrStep(self.attr, doc.attrs[self.attr]) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: return self def to_json(self) -> JSONDict: @@ -39,16 +39,15 @@ def to_json(self) -> JSONDict: return json_data @staticmethod - def from_json( - schema: Schema[Any, Any], json_data: Union[JSONDict, str] - ) -> "DocAttrStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "DocAttrStep": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["attr"], str): - raise ValueError("Invalid input for DocAttrStep.from_json") + msg = "Invalid input for DocAttrStep.from_json" + raise ValueError(msg) return DocAttrStep(json_data["attr"], json_data["value"]) diff --git a/prosemirror/transform/map.py b/prosemirror/transform/map.py index e875bed..9325394 100644 --- a/prosemirror/transform/map.py +++ b/prosemirror/transform/map.py @@ -1,5 +1,6 @@ import abc -from typing import Callable, ClassVar, List, Literal, Optional, Union, overload +from collections.abc import Callable +from typing import ClassVar, Literal, overload lower16 = 0xFFFF factor16 = 2**16 @@ -24,9 +25,7 @@ def recover_offset(value: int) -> int: class MapResult: - def __init__( - self, pos: int, del_info: int = 0, recover: Optional[int] = None - ) -> None: + def __init__(self, pos: int, del_info: int = 0, recover: int | None = None) -> None: self.pos = pos self.del_info = del_info self.recover = recover @@ -67,7 +66,7 @@ def map_result(self, pos: int, assoc: int = 1) -> MapResult: ... class StepMap(Mappable): empty: ClassVar["StepMap"] - def __init__(self, ranges: List[int], inverted: bool = False) -> None: + def __init__(self, ranges: list[int], inverted: bool = False) -> None: # prosemirror-transform overrides the constructor to return the # StepMap.empty singleton when ranges are empty. # It is not easy to do in Python, and the intent of that is to make sure @@ -95,7 +94,7 @@ def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: ... @overload def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... - def _map(self, pos: int, assoc: int, simple: bool) -> Union[MapResult, int]: + def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: diff = 0 old_index = 2 if self.inverted else 1 new_index = 1 if self.inverted else 2 @@ -177,27 +176,30 @@ def __str__(self) -> str: class Mapping(Mappable): def __init__( self, - maps: Optional[List[StepMap]] = None, - mirror: Optional[List[int]] = None, - from_: Optional[int] = None, - to: Optional[int] = None, + maps: list[StepMap] | None = None, + mirror: list[int] | None = None, + from_: int | None = None, + to: int | None = None, ) -> None: self.maps = maps or [] self.from_ = from_ or 0 self.to = len(self.maps) if to is None else to self.mirror = mirror - def slice(self, from_: int = 0, to: Optional[int] = None) -> "Mapping": + def slice(self, from_: int = 0, to: int | None = None) -> "Mapping": if to is None: to = len(self.maps) return Mapping(self.maps, self.mirror, from_, to) def copy(self) -> "Mapping": return Mapping( - self.maps[:], (self.mirror[:] if self.mirror else None), self.from_, self.to + self.maps[:], + (self.mirror[:] if self.mirror else None), + self.from_, + self.to, ) - def append_map(self, map: StepMap, mirrors: Optional[int] = None) -> None: + def append_map(self, map: StepMap, mirrors: int | None = None) -> None: self.maps.append(map) self.to = len(self.maps) if mirrors is not None: @@ -214,7 +216,7 @@ def append_mapping(self, mapping: "Mapping") -> None: (start_size + mirr) if (mirr is not None and mirr < i) else None, ) - def get_mirror(self, n: int) -> Optional[int]: + def get_mirror(self, n: int) -> int | None: if self.mirror: for i in range(len(self.mirror)): if (self.mirror[i]) == n: @@ -258,7 +260,7 @@ def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: ... @overload def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... - def _map(self, pos: int, assoc: int, simple: bool) -> Union[MapResult, int]: + def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: del_info = 0 i = self.from_ diff --git a/prosemirror/transform/mark_step.py b/prosemirror/transform/mark_step.py index 7bb897e..8945ede 100644 --- a/prosemirror/transform/mark_step.py +++ b/prosemirror/transform/mark_step.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional, Union, cast +from collections.abc import Callable +from typing import Any, cast from prosemirror.model import Fragment, Mark, Node, Schema, Slice from prosemirror.transform.map import Mappable @@ -34,7 +35,7 @@ def apply(self, doc: Node) -> StepResult: from__ = doc.resolve(self.from_) parent = from__.node(from__.shared_depth(self.to)) - def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: + def iteratee(node: Node, parent: Node | None, i: int) -> Node: if parent and ( not node.is_atom or not parent.type.allows_mark_type(self.mark.type) ): @@ -48,17 +49,17 @@ def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Optional[Node] = None) -> Step: + def invert(self, doc: Node | None = None) -> Step: return RemoveMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if (from_.deleted and to.deleted) or from_.pos > to.pos: return None return AddMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: Step) -> Optional[Step]: + def merge(self, other: Step) -> Step | None: if ( isinstance(other, AddMarkStep) and other.mark.eq(self.mark) @@ -66,7 +67,9 @@ def merge(self, other: Step) -> Optional[Step]: and self.to >= other.from_ ): return AddMarkStep( - min(self.from_, other.from_), max(self.to, other.to), self.mark + min(self.from_, other.from_), + max(self.to, other.to), + self.mark, ) return None @@ -79,18 +82,18 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json( - schema: Schema[Any, Any], json_data: Union[JSONDict, str] - ) -> "AddMarkStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AddMarkStep": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["from"], int) or not isinstance( - json_data["to"], int + json_data["to"], + int, ): - raise ValueError("Invalid input for AddMarkStep.from_json") + msg = "Invalid input for AddMarkStep.from_json" + raise ValueError(msg) return AddMarkStep( json_data["from"], json_data["to"], @@ -111,7 +114,7 @@ def __init__(self, from_: int, to: int, mark: Mark) -> None: def apply(self, doc: Node) -> StepResult: old_slice = doc.slice(self.from_, self.to) - def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: + def iteratee(node: Node, parent: Node | None, i: int) -> Node: return node.mark(self.mark.remove_from_set(node.marks)) slice = Slice( @@ -121,17 +124,17 @@ def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Optional[Node] = None) -> Step: + def invert(self, doc: Node | None = None) -> Step: return AddMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if (from_.deleted and to.deleted) or (from_.pos > to.pos): return None return RemoveMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: Step) -> Optional[Step]: + def merge(self, other: Step) -> Step | None: if ( isinstance(other, RemoveMarkStep) and (other.mark.eq(self.mark)) @@ -139,7 +142,9 @@ def merge(self, other: Step) -> Optional[Step]: and self.to >= other.from_ ): return RemoveMarkStep( - min(self.from_, other.from_), max(self.to, other.to), self.mark + min(self.from_, other.from_), + max(self.to, other.to), + self.mark, ) return None @@ -152,16 +157,18 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["from"], int) or not isinstance( - json_data["to"], int + json_data["to"], + int, ): - raise ValueError("Invalid input for RemoveMarkStep.from_json") + msg = "Invalid input for RemoveMarkStep.from_json" + raise ValueError(msg) return RemoveMarkStep( json_data["from"], json_data["to"], @@ -201,7 +208,7 @@ def invert(self, doc: Node) -> Step: return AddNodeMarkStep(self.pos, self.mark) return RemoveNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AddNodeMarkStep(pos.pos, self.mark) @@ -213,16 +220,18 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["pos"], int): - raise ValueError("Invalid input for AddNodeMarkStep.from_json") + msg = "Invalid input for AddNodeMarkStep.from_json" + raise ValueError(msg) return AddNodeMarkStep( - json_data["pos"], schema.mark_from_json(cast(JSONDict, json_data["mark"])) + json_data["pos"], + schema.mark_from_json(cast(JSONDict, json_data["mark"])), ) @@ -240,7 +249,9 @@ def apply(self, doc: Node) -> StepResult: if not node: return StepResult.fail("No node at mark step's position") updated = node.type.create( - node.attrs, None, self.mark.remove_from_set(node.marks) + node.attrs, + None, + self.mark.remove_from_set(node.marks), ) return StepResult.from_replace( doc, @@ -255,7 +266,7 @@ def invert(self, doc: Node) -> Step: return self return AddNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> Optional[Step]: + def map(self, mapping: Mappable) -> Step | None: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else RemoveNodeMarkStep(pos.pos, self.mark) @@ -267,16 +278,18 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> Step: if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["pos"], int): - raise ValueError("Invalid input for RemoveNodeMarkStep.from_json") + msg = "Invalid input for RemoveNodeMarkStep.from_json" + raise ValueError(msg) return RemoveNodeMarkStep( - json_data["pos"], schema.mark_from_json(cast(JSONDict, json_data["mark"])) + json_data["pos"], + schema.mark_from_json(cast(JSONDict, json_data["mark"])), ) diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index 931554d..c301651 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import cast from prosemirror.model import ( ContentMatch, @@ -16,9 +16,9 @@ def replace_step( doc: Node, from_: int, - to: Optional[int] = None, - slice: Optional[Slice] = None, -) -> Optional[Step]: + to: int | None = None, + slice: Slice | None = None, +) -> Step | None: if to is None: to = from_ if slice is None: @@ -58,9 +58,9 @@ def __init__( self, slice_depth: int, frontier_depth: int, - parent: Optional[Node], - inject: Optional[Fragment] = None, - wrap: Optional[List[NodeType]] = None, + parent: Node | None, + inject: Fragment | None = None, + wrap: list[NodeType] | None = None, ) -> None: self.slice_depth = slice_depth self.frontier_depth = frontier_depth @@ -91,7 +91,7 @@ def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice) -> None: self.from__ = from__ self.unplaced = slice - self.frontier: List[_FrontierItem] = [] + self.frontier: list[_FrontierItem] = [] for i in range(from__.depth + 1): node = from__.node(i) self.frontier.append( @@ -106,7 +106,7 @@ def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice) -> None: def depth(self) -> int: return len(self.frontier) - 1 - def fit(self) -> Optional[Step]: + def fit(self) -> Step | None: while self.unplaced.size: fit = self.find_fittable() if fit: @@ -118,7 +118,7 @@ def fit(self) -> Optional[Step]: placed_size = self.placed.size - self.depth - self.from__.depth from__ = self.from__ to_ = self.close( - self.to_ if move_inline < 0 else from__.doc.resolve(move_inline) + self.to_ if move_inline < 0 else from__.doc.resolve(move_inline), ) if not to_: return None @@ -147,7 +147,7 @@ def fit(self) -> Optional[Step]: return ReplaceStep(from__.pos, to_.pos, slice) return None - def find_fittable(self) -> Optional[_Fittable]: + def find_fittable(self) -> _Fittable | None: start_depth = self.unplaced.open_start cur = self.unplaced.content open_end = self.unplaced.open_end @@ -162,11 +162,14 @@ def find_fittable(self) -> Optional[_Fittable]: for pass_ in [1, 2]: for slice_depth in range( - start_depth if pass_ == 1 else self.unplaced.open_start, -1, -1 + start_depth if pass_ == 1 else self.unplaced.open_start, + -1, + -1, ): if slice_depth: parent = content_at( - self.unplaced.content, slice_depth - 1 + self.unplaced.content, + slice_depth - 1, ).first_child assert parent fragment = parent.content @@ -178,25 +181,19 @@ def find_fittable(self) -> Optional[_Fittable]: type_ = self.frontier[frontier_depth].type match = self.frontier[frontier_depth].match - _nothing = object() - inject = _nothing - wrap = _nothing - - def _lazy_inject() -> Optional[Fragment]: - nonlocal inject - if inject is _nothing: - inject = match.fill_before(Fragment.from_(first), False) - return cast(Optional[Fragment], inject) - - def _lazy_wrap() -> Optional[List[NodeType]]: - nonlocal wrap - assert first is not None - if wrap is _nothing: - wrap = match.find_wrapping(first.type) - return cast(Optional[List[NodeType]], wrap) + inject = None + wrap = None if pass_ == 1 and ( - (match.match_type(first.type) or _lazy_inject()) + ( + match.match_type(first.type) + or ( + inject := match.fill_before( + Fragment.from_(first), + False, + ) + ) + ) if first else parent and type_.compatible_content(parent.type) ): @@ -204,14 +201,18 @@ def _lazy_wrap() -> Optional[List[NodeType]]: slice_depth, frontier_depth, parent, - inject=_lazy_inject(), + inject=inject, ) - elif pass_ == 2 and first and _lazy_wrap(): + elif ( + pass_ == 2 + and first + and (wrap := match.find_wrapping(first.type)) + ): return _Fittable( slice_depth, frontier_depth, parent, - wrap=_lazy_wrap(), + wrap=wrap, ) if parent and match.match_type(parent.type): break @@ -300,7 +301,7 @@ def place_nodes(self, fittable: _Fittable) -> None: next_.mark(type_.allowed_marks(next_.marks)), open_start if taken == 1 else 0, open_end_count if taken == fragment.child_count else -1, - ) + ), ) to_end = taken == fragment.child_count @@ -328,7 +329,7 @@ def place_nodes(self, fittable: _Fittable) -> None: node = cur.last_child assert node is not None self.frontier.append( - _FrontierItem(node.type, node.content_match_at(node.child_count)) + _FrontierItem(node.type, node.content_match_at(node.child_count)), ) cur = node.content @@ -348,23 +349,27 @@ def place_nodes(self, fittable: _Fittable) -> None: ) def must_move_inline(self) -> int: - if not self.to_.parent.is_text_block: + if not self.to_.parent.is_textblock: return -1 top = self.frontier[self.depth] _nothing = object() level = _nothing - def _lazy_level() -> Optional[_CloseLevel]: + def _lazy_level() -> _CloseLevel | None: nonlocal level if level is _nothing: level = self.find_close_level(self.to_) - return cast(Optional[_CloseLevel], level) + return cast(_CloseLevel | None, level) if ( - not top.type.is_text_block + not top.type.is_textblock or not content_after_fits( - self.to_, self.to_.depth, top.type, top.match, False + self.to_, + self.to_.depth, + top.type, + top.match, + False, ) or ( self.to_.depth == self.depth @@ -383,7 +388,7 @@ def _lazy_level() -> Optional[_CloseLevel]: after += 1 return after - def find_close_level(self, to_: ResolvedPos) -> Optional[_CloseLevel]: + def find_close_level(self, to_: ResolvedPos) -> _CloseLevel | None: for i in range(min(self.depth, to_.depth), -1, -1): match = self.frontier[i].match type_ = self.frontier[i].type @@ -406,7 +411,7 @@ def find_close_level(self, to_: ResolvedPos) -> Optional[_CloseLevel]: ) return None - def close(self, to_: ResolvedPos) -> Optional[ResolvedPos]: + def close(self, to_: ResolvedPos) -> ResolvedPos | None: close = self.find_close_level(to_) if not close: return None @@ -425,15 +430,17 @@ def close(self, to_: ResolvedPos) -> Optional[ResolvedPos]: def open_frontier_node( self, type_: NodeType, - attrs: Optional[Attrs] = None, - content: Optional[Fragment] = None, + attrs: Attrs | None = None, + content: Fragment | None = None, ) -> None: top = self.frontier[self.depth] top_match = top.match.match_type(type_) assert top_match is not None top.match = top_match self.placed = add_to_fragment( - self.placed, self.depth, Fragment.from_(type_.create(attrs, content)) + self.placed, + self.depth, + Fragment.from_(type_.create(attrs, content)), ) self.frontier.append(_FrontierItem(type_, type_.content_match)) @@ -505,7 +512,7 @@ def content_after_fits( type_: NodeType, match: ContentMatch, open_: bool, -) -> Optional[Fragment]: +) -> Fragment | None: node = to_.node(depth) index = to_.index_after(depth) if open_ else to_.index(depth) if index == node.child_count and not type_.compatible_content(node.type): @@ -526,7 +533,7 @@ def close_fragment( depth: int, old_open: int, new_open: int, - parent: Optional[Node], + parent: Node | None, ) -> Fragment: if depth < old_open: first = fragment.first_child @@ -534,7 +541,7 @@ def close_fragment( fragment = fragment.replace_child( 0, first.copy( - close_fragment(first.content, depth + 1, old_open, new_open, first) + close_fragment(first.content, depth + 1, old_open, new_open, first), ), ) if depth > new_open: @@ -546,7 +553,8 @@ def close_fragment( matched_fragment = match.match_fragment(start) assert matched_fragment is not None matched_fragment_fill_before = matched_fragment.fill_before( - Fragment.empty, True + Fragment.empty, + True, ) assert matched_fragment_fill_before is not None fragment = start.append(matched_fragment_fill_before) @@ -557,7 +565,7 @@ def close_fragment( def covered_depths( from__: ResolvedPos, to_: ResolvedPos, -) -> List[int]: +) -> list[int]: result = [] min_depth = min(from__.depth, to_.depth) for d in range(min_depth, -1, -1): diff --git a/prosemirror/transform/replace_step.py b/prosemirror/transform/replace_step.py index d45386b..bbfc3cd 100644 --- a/prosemirror/transform/replace_step.py +++ b/prosemirror/transform/replace_step.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast from prosemirror.model import Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -8,7 +8,11 @@ class ReplaceStep(Step): def __init__( - self, from_: int, to: int, slice: Slice, structure: Optional[bool] = None + self, + from_: int, + to: int, + slice: Slice, + structure: bool | None = None, ) -> None: super().__init__() self.from_ = from_ @@ -26,7 +30,9 @@ def get_map(self) -> StepMap: def invert(self, doc: Node) -> "ReplaceStep": return ReplaceStep( - self.from_, self.from_ + self.slice.size, doc.slice(self.from_, self.to) + self.from_, + self.from_ + self.slice.size, + doc.slice(self.from_, self.to), ) def map(self, mapping: Mappable) -> Optional["ReplaceStep"]: @@ -53,7 +59,10 @@ def merge(self, other: "Step") -> Optional["ReplaceStep"]: other.slice.open_end, ) return ReplaceStep( - self.from_, self.to + (other.to - other.from_), slice, self.structure + self.from_, + self.to + (other.to - other.from_), + slice, + self.structure, ) elif ( other.to == self.from_ @@ -86,22 +95,22 @@ def to_json(self) -> JSONDict: return json_data @staticmethod - def from_json( - schema: Schema[Any, Any], json_data: Union[JSONDict, str] - ) -> "ReplaceStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "ReplaceStep": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not isinstance(json_data["from"], int) or not isinstance( - json_data["to"], int + json_data["to"], + int, ): - raise ValueError("Invlid input for ReplaceStep.from_json") + msg = "Invlid input for ReplaceStep.from_json" + raise ValueError(msg) return ReplaceStep( json_data["from"], json_data["to"], - Slice.from_json(schema, cast(Optional[JSONDict], json_data.get("slice"))), + Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), bool(json_data.get("structure")), ) @@ -118,7 +127,7 @@ def __init__( gap_to: int, slice: Slice, insert: int, - structure: Optional[bool] = None, + structure: bool | None = None, ) -> None: super().__init__() self.from_ = from_ @@ -161,7 +170,8 @@ def invert(self, doc: Node) -> "ReplaceAroundStep": self.from_ + self.insert, self.from_ + self.insert + gap, doc.slice(self.from_, self.to).remove_between( - self.gap_from - self.from_, self.gap_to - self.from_ + self.gap_from - self.from_, + self.gap_to - self.from_, ), self.gap_from - self.from_, self.structure, @@ -175,7 +185,13 @@ def map(self, mapping: Mappable) -> Optional["ReplaceAroundStep"]: if (from_.deleted and to.deleted) or gap_from < from_.pos or gap_to > to.pos: return None return ReplaceAroundStep( - from_.pos, to.pos, gap_from, gap_to, self.slice, self.insert, self.structure + from_.pos, + to.pos, + gap_from, + gap_to, + self.slice, + self.insert, + self.structure, ) def to_json(self) -> JSONDict: @@ -201,7 +217,8 @@ def to_json(self) -> JSONDict: @staticmethod def from_json( - schema: Schema[Any, Any], json_data: Union[JSONDict, str] + schema: Schema[Any, Any], + json_data: JSONDict | str, ) -> "ReplaceAroundStep": if isinstance(json_data, str): import json @@ -215,13 +232,14 @@ def from_json( or not isinstance(json_data["gapTo"], int) or not isinstance(json_data["insert"], int) ): - raise ValueError("Invlid input for ReplaceAroundStep.from_json") + msg = "Invlid input for ReplaceAroundStep.from_json" + raise ValueError(msg) return ReplaceAroundStep( json_data["from"], json_data["to"], json_data["gapFrom"], json_data["gapTo"], - Slice.from_json(schema, cast(Optional[JSONDict], json_data.get("slice"))), + Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), json_data["insert"], bool(json_data.get("structure")), ) diff --git a/prosemirror/transform/step.py b/prosemirror/transform/step.py index fd039d6..ebf1432 100644 --- a/prosemirror/transform/step.py +++ b/prosemirror/transform/step.py @@ -1,12 +1,12 @@ import abc -from typing import Any, Dict, Literal, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Literal, Optional, TypeVar, cast, overload from prosemirror.model import Node, ReplaceError, Schema, Slice from prosemirror.transform.map import Mappable, StepMap from prosemirror.utils import JSONDict # like a registry -STEPS_BY_ID: Dict[str, Type["Step"]] = {} +STEPS_BY_ID: dict[str, type["Step"]] = {} StepSubclass = TypeVar("StepSubclass", bound="Step") @@ -32,23 +32,26 @@ def merge(self, _other: "Step") -> Optional["Step"]: def to_json(self) -> JSONDict: ... @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json json_data = cast(JSONDict, json.loads(json_data)) if not json_data or not json_data.get("stepType"): - raise ValueError("Invalid inpit for Step.from_json") + msg = "Invalid inpit for Step.from_json" + raise ValueError(msg) type = STEPS_BY_ID.get(cast(str, json_data["stepType"])) if not type: - raise ValueError(f'no step type {json_data["stepType"]} defined') + msg = f'no step type {json_data["stepType"]} defined' + raise ValueError(msg) return type.from_json(schema, json_data) -def step_json_id(id: str, step_class: Type[StepSubclass]) -> Type[StepSubclass]: +def step_json_id(id: str, step_class: type[StepSubclass]) -> type[StepSubclass]: if id in STEPS_BY_ID: - raise ValueError(f"Duplicated JSON ID for step type: {id}") + msg = f"Duplicated JSON ID for step type: {id}" + raise ValueError(msg) STEPS_BY_ID[id] = step_class step_class.json_id = id @@ -63,7 +66,7 @@ def __init__(self, doc: Node, failed: Literal[None]) -> None: ... @overload def __init__(self, doc: None, failed: str) -> None: ... - def __init__(self, doc: Optional[Node], failed: Optional[str]) -> None: + def __init__(self, doc: Node | None, failed: str | None) -> None: self.doc = doc self.failed = failed diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index 15d10ca..634f918 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -1,4 +1,5 @@ -from typing import List, Optional, TypedDict, Union +from dataclasses import dataclass +from typing import cast from prosemirror.model import ContentMatch, Node, NodeRange, NodeType, Slice from prosemirror.utils import Attrs @@ -10,7 +11,7 @@ def can_cut(node: Node, start: int, end: int) -> bool: return False -def lift_target(range_: NodeRange) -> Optional[int]: +def lift_target(range_: NodeRange) -> int | None: parent = range_.parent content = parent.content.cut_by_index(range_.start_index, range_.end_index) depth = range_.depth @@ -31,17 +32,18 @@ def lift_target(range_: NodeRange) -> Optional[int]: return None -class NodeTypeWithAttrs(TypedDict): +@dataclass +class NodeTypeWithAttrs: type: NodeType - attrs: Optional[Attrs] + attrs: Attrs | None = None def find_wrapping( range_: NodeRange, node_type: NodeType, - attrs: Optional[Attrs] = None, - inner_range: Optional[NodeRange] = None, -) -> Optional[List[NodeTypeWithAttrs]]: + attrs: Attrs | None = None, + inner_range: NodeRange | None = None, +) -> list[NodeTypeWithAttrs] | None: if inner_range is None: inner_range = range_ @@ -58,7 +60,7 @@ def find_wrapping( return ( [with_attrs(item) for item in around] - + [{"type": node_type, "attrs": attrs}] + + [NodeTypeWithAttrs(type=node_type, attrs=attrs)] + [with_attrs(item) for item in inner] ) @@ -67,9 +69,7 @@ def with_attrs(type: NodeType) -> NodeTypeWithAttrs: return NodeTypeWithAttrs(type=type, attrs=None) -def find_wrapping_outside( - range_: NodeRange, type: NodeType -) -> Optional[List[NodeType]]: +def find_wrapping_outside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -80,7 +80,7 @@ def find_wrapping_outside( return around if parent.can_replace_with(start_index, end_index, outer) else None -def find_wrapping_inside(range_: NodeRange, type: NodeType) -> Optional[List[NodeType]]: +def find_wrapping_inside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -91,7 +91,7 @@ def find_wrapping_inside(range_: NodeRange, type: NodeType) -> Optional[List[Nod return None last_type = inside[-1] if len(inside) else type - inner_match: Optional[ContentMatch] = last_type.content_match + inner_match: ContentMatch | None = last_type.content_match i = start_index while inner_match and i < end_index: @@ -113,42 +113,27 @@ def can_change_type(doc: Node, pos: int, type: NodeType) -> bool: def can_split( doc: Node, pos: int, - depth: Optional[int] = None, - types_after: Optional[List[NodeTypeWithAttrs]] = None, + depth: int | None = None, + types_after: list[NodeTypeWithAttrs] | None = None, ) -> bool: if depth is None: depth = 1 pos_ = doc.resolve(pos) base = pos_.depth - depth - inner_type: Union[NodeTypeWithAttrs, Node, None] = None - - if types_after: - inner_type = types_after[-1] - - if not inner_type: - inner_type = pos_.parent - - if isinstance(inner_type, Node): - if ( - base < 0 - or pos_.parent.type.spec.get("isolating") - or not pos_.parent.can_replace(pos_.index(), pos_.parent.child_count) - or not inner_type.type.valid_content( - pos_.parent.content.cut_by_index(pos_.index(), pos_.parent.child_count) - ) - ): - return False + inner_type: NodeTypeWithAttrs = cast( + NodeTypeWithAttrs, + (types_after and types_after[-1]) or pos_.parent, + ) - elif isinstance(inner_type, dict): - if ( - base < 0 - or pos_.parent.type.spec.get("isolating") - or not pos_.parent.can_replace(pos_.index(), pos_.parent.child_count) - or not inner_type["type"].valid_content( - pos_.parent.content.cut_by_index(pos_.index(), pos_.parent.child_count) - ) - ): - return False + if ( + base < 0 + or pos_.parent.type.spec.get("isolating") + or not pos_.parent.can_replace(pos_.index(), pos_.parent.child_count) + or not inner_type.type.valid_content( + pos_.parent.content.cut_by_index(pos_.index(), pos_.parent.child_count), + ) + ): + return False d = pos_.depth - 1 i = depth - 2 @@ -160,40 +145,33 @@ def can_split( return False rest = node.content.cut_by_index(index, node.child_count) - if types_after and len(types_after) > i: + if types_after and len(types_after) > i + 1: override_child = types_after[i + 1] rest = rest.replace_child( - 0, override_child["type"].create(override_child.get("attrs")) + 0, + override_child.type.create(override_child.attrs), ) - after: Union[NodeTypeWithAttrs, Node, None] = None - if types_after and len(types_after) > i: - after = types_after[i] - if not after: - after = node - - if isinstance(after, dict): - if not node.can_replace(index + 1, node.child_count) or not after[ - "type" - ].valid_content(rest): - return False - - if isinstance(after, Node): - if after != node: - rest = rest.replace_child(0, after.type.create(after.attrs)) - if not node.can_replace( - index + 1, node.child_count - ) or not after.type.valid_content(rest): - return False + after: NodeTypeWithAttrs = cast( + NodeTypeWithAttrs, + (types_after and len(types_after) > i and types_after[i]) or node, + ) + if not node.can_replace( + index + 1, + node.child_count, + ) or not after.type.valid_content(rest): + return False d -= 1 i -= 1 index = pos_.index_after(base) base_type = types_after[0] if types_after else None return pos_.node(base).can_replace_with( - index, index, base_type["type"] if base_type else pos_.node(base + 1).type + index, + index, + base_type.type if base_type else pos_.node(base + 1).type, ) -def can_join(doc: Node, pos: int) -> Optional[bool]: +def can_join(doc: Node, pos: int) -> bool | None: pos_ = doc.resolve(pos) index = pos_.index() return ( @@ -203,13 +181,13 @@ def can_join(doc: Node, pos: int) -> Optional[bool]: ) -def joinable(a: Optional[Node], b: Optional[Node]) -> bool: +def joinable(a: Node | None, b: Node | None) -> bool: if a and b and not a.is_leaf: return a.can_append(b) return False -def join_point(doc: Node, pos: int, dir: int = -1) -> Optional[int]: +def join_point(doc: Node, pos: int, dir: int = -1) -> int | None: pos_ = doc.resolve(pos) for d in range(pos_.depth, -1, -1): before = None @@ -227,7 +205,7 @@ def join_point(doc: Node, pos: int, dir: int = -1) -> Optional[int]: after = pos_.node(d + 1) if ( before - and not before.is_text_block + and not before.is_textblock and joinable(before, after) and pos_.node(d).can_replace(index, index + 1) ): @@ -239,7 +217,7 @@ def join_point(doc: Node, pos: int, dir: int = -1) -> Optional[int]: return None -def insert_point(doc: Node, pos: int, node_type: NodeType) -> Optional[int]: +def insert_point(doc: Node, pos: int, node_type: NodeType) -> int | None: pos_ = doc.resolve(pos) if pos_.parent.can_replace_with(pos_.index(), pos_.index(), node_type): return pos @@ -261,12 +239,12 @@ def insert_point(doc: Node, pos: int, node_type: NodeType) -> Optional[int]: return None -def drop_point(doc: Node, pos: int, slice: Slice) -> Optional[int]: +def drop_point(doc: Node, pos: int, slice: Slice) -> int | None: pos_ = doc.resolve(pos) if not slice.content.size: return pos content = slice.content - for i in range(slice.open_start): + for _i in range(slice.open_start): assert content.first_child is not None content = content.first_child.content pass_ = 1 @@ -286,10 +264,12 @@ def drop_point(doc: Node, pos: int, slice: Slice) -> Optional[int]: else: assert content.first_child is not None wrapping = parent.content_match_at(insert_pos).find_wrapping( - content.first_child.type + content.first_child.type, ) fits = wrapping is not None and parent.can_replace_with( - insert_pos, insert_pos, wrapping[0] + insert_pos, + insert_pos, + wrapping[0], ) if fits: if bias == 0: diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 107f524..11fa977 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -1,5 +1,5 @@ import re -from typing import List, Optional, TypedDict, Union +from typing import Optional, TypedDict from prosemirror.model import ( ContentMatch, @@ -34,7 +34,7 @@ from .doc_attr_step import DocAttrStep -def defines_content(type: Union[NodeType, MarkType]) -> Optional[bool]: +def defines_content(type: NodeType | MarkType) -> bool | None: if isinstance(type, NodeType): return type.spec.get("defining") or type.spec.get("definingForContent") return False @@ -57,8 +57,8 @@ class Transform: def __init__(self, doc: Node) -> None: self.doc = doc - self.steps: List[Step] = [] - self.docs: List[Node] = [] + self.steps: list[Step] = [] + self.docs: list[Node] = [] self.mapping = Mapping() @property @@ -90,10 +90,10 @@ def add_step(self, step: Step, doc: Node) -> None: def add_mark(self, from_: int, to: int, mark: Mark) -> "Transform": removed = [] added = [] - removing: Optional[RemoveMarkStep] = None - adding: Optional[AddMarkStep] = None + removing: RemoveMarkStep | None = None + adding: AddMarkStep | None = None - def iteratee(node: Node, pos: int, parent: Optional[Node], i: int) -> None: + def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> None: nonlocal removing nonlocal adding if not node.is_inline: @@ -136,7 +136,7 @@ def remove_mark( self, from_: int, to: int, - mark: Union[Mark, MarkType, None] = None, + mark: Mark | MarkType | None = None, ) -> "Transform": class MatchedTypedDict(TypedDict): style: Mark @@ -144,12 +144,10 @@ class MatchedTypedDict(TypedDict): to: int step: int - matched: List[MatchedTypedDict] = [] + matched: list[MatchedTypedDict] = [] step = 0 - def iteratee( - node: Node, pos: int, parent: Optional[Node], i: int - ) -> Optional[bool]: + def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> bool | None: nonlocal step if not node.is_inline: return None @@ -198,7 +196,7 @@ def clear_incompatible( self, pos: int, parent_type: NodeType, - match: Optional[ContentMatch] = None, + match: ContentMatch | None = None, ) -> "Transform": if match is None: match = parent_type.content_match @@ -229,14 +227,15 @@ def clear_incompatible( slice = Slice( Fragment.from_( parent_type.schema.text( - " ", parent_type.allowed_marks(child.marks) - ) + " ", + parent_type.allowed_marks(child.marks), + ), ), 0, 0, ) repl_steps.append( - ReplaceStep(cur + m.start(), cur + m.end(), slice) + ReplaceStep(cur + m.start(), cur + m.end(), slice), ) m = newline.search(child.text, m.end()) cur = end @@ -252,8 +251,8 @@ def clear_incompatible( def replace( self, from_: int, - to: Optional[int] = None, - slice: Optional[Slice] = None, + to: int | None = None, + slice: Slice | None = None, ) -> "Transform": if to is None: to = from_ @@ -268,7 +267,7 @@ def replace_with( self, from_: int, to: int, - content: Union[Fragment, Node, List[Node]], + content: Fragment | Node | list[Node], ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) @@ -278,7 +277,7 @@ def delete(self, from_: int, to: int) -> "Transform": def insert( self, pos: int, - content: Union[Fragment, Node, List[Node]], + content: Fragment | Node | list[Node], ) -> "Transform": return self.replace_with(pos, pos, content) @@ -331,10 +330,10 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": assert left_node is not None def_ = defines_content(left_node.type) if def_ and not left_node.same_markup( - from__.node(abs(preferred_target) - 1) + from__.node(abs(preferred_target) - 1), ): preferred_depth = d - elif def_ or not left_node.type.is_text_block: + elif def_ or not left_node.type.is_textblock: break d -= 1 @@ -359,7 +358,11 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": to_.after(target_depth) if expand else to, Slice( close_fragment( - slice.content, 0, slice.open_start, open_depth, None + slice.content, + 0, + slice.open_start, + open_depth, + None, ), open_depth, slice.open_end, @@ -405,7 +408,8 @@ def delete_range(self, from_: int, to: int) -> "Transform": if depth > 0 and ( last or from__.node(depth - 1).can_replace( - from__.index(depth - 1), to_.index_after(depth - 1) + from__.index(depth - 1), + to_.index_after(depth - 1), ) ): return self.delete(from__.before(depth), to_.after(depth)) @@ -467,55 +471,70 @@ def lift(self, range_: NodeRange, target: int) -> "Transform": Slice(before.append(after), open_start, open_end), before.size - open_start, True, - ) + ), ) def wrap( - self, range_: NodeRange, wrappers: List[structure.NodeTypeWithAttrs] + self, + range_: NodeRange, + wrappers: list[structure.NodeTypeWithAttrs], ) -> "Transform": content = Fragment.empty i = len(wrappers) - 1 while i >= 0: if content.size: - match = wrappers[i]["type"].content_match.match_fragment(content) + match = wrappers[i].type.content_match.match_fragment(content) if not match or not match.valid_end: - raise TransformError( + msg = ( "Wrapper type given to Transform.wrap does not form valid " "content of its parent wrapper" ) + raise TransformError(msg) content = Fragment.from_( - wrappers[i]["type"].create(wrappers[i].get("attrs"), content) + wrappers[i].type.create(wrappers[i].attrs, content), ) i -= 1 start = range_.start end = range_.end return self.step( ReplaceAroundStep( - start, end, start, end, Slice(content, 0, 0), len(wrappers), True - ) + start, + end, + start, + end, + Slice(content, 0, 0), + len(wrappers), + True, + ), ) def set_block_type( self, from_: int, - to: Optional[int], + to: int | None, type: NodeType, - attrs: Optional[Attrs], + attrs: Attrs | None, ) -> "Transform": if to is None: to = from_ - if not type.is_text_block: - raise ValueError("Type given to set_block_type should be a textblock") + if not type.is_textblock: + msg = "Type given to set_block_type should be a textblock" + raise ValueError(msg) map_from = len(self.steps) def iteratee( - node: "Node", pos: int, parent: Optional["Node"], i: int - ) -> Optional[bool]: + node: "Node", + pos: int, + parent: Optional["Node"], + i: int, + ) -> bool | None: if ( - node.is_text_block + node.is_textblock and not node.has_markup(type, attrs) and structure.can_change_type( - self.doc, self.mapping.slice(map_from).map(pos), type + self.doc, + self.mapping.slice(map_from).map(pos), + type, ) ): self.clear_incompatible(self.mapping.slice(map_from).map(pos, 1), type) @@ -529,11 +548,13 @@ def iteratee( start_m + 1, end_m - 1, Slice( - Fragment.from_(type.create(attrs, None, node.marks)), 0, 0 + Fragment.from_(type.create(attrs, None, node.marks)), + 0, + 0, ), 1, True, - ) + ), ) return False return None @@ -544,20 +565,22 @@ def iteratee( def set_node_markup( self, pos: int, - type: Optional[NodeType], - attrs: Optional[Attrs], - marks: Optional[List[Mark]] = None, + type: NodeType | None, + attrs: Attrs | None, + marks: list[Mark] | None = None, ) -> "Transform": node = self.doc.node_at(pos) if not node: - raise ValueError("No node at given position") + msg = "No node at given position" + raise ValueError(msg) if not type: type = node.type new_node = type.create(attrs, None, marks or node.marks) if node.is_leaf: return self.replace_with(pos, pos + node.node_size, new_node) if not type.valid_content(node.content): - raise ValueError(f"Invalid content for node type {type.name}") + msg = f"Invalid content for node type {type.name}" + raise ValueError(msg) return self.step( ReplaceAroundStep( pos, @@ -567,7 +590,7 @@ def set_node_markup( Slice(Fragment.from_(new_node), 0, 0), 1, True, - ) + ), ) def set_node_attribute(self, pos: int, attr: str, value: JSON) -> "Transform": @@ -579,12 +602,13 @@ def set_doc_attribute(self, attr: str, value: JSON) -> "Transform": def add_node_mark(self, pos: int, mark: Mark) -> "Transform": return self.step(AddNodeMarkStep(pos, mark)) - def remove_node_mark(self, pos: int, mark: Union[Mark, MarkType]) -> "Transform": + def remove_node_mark(self, pos: int, mark: Mark | MarkType) -> "Transform": if isinstance(mark, MarkType): node = self.doc.node_at(pos) if not node: - raise ValueError(f"No node at position {pos}") + msg = f"No node at position {pos}" + raise ValueError(msg) mark_in_set = mark.is_in_set(node.marks) @@ -597,8 +621,8 @@ def remove_node_mark(self, pos: int, mark: Union[Mark, MarkType]) -> "Transform" def split( self, pos: int, - depth: Optional[int] = None, - types_after: Optional[List[structure.NodeTypeWithAttrs]] = None, + depth: int | None = None, + types_after: list[structure.NodeTypeWithAttrs] | None = None, ) -> "Transform": if depth is None: depth = 1 @@ -614,14 +638,14 @@ def split( if types_after and len(types_after) > i: type_after = types_after[i] after = Fragment.from_( - type_after["type"].create(type_after.get("attrs"), after) + type_after.type.create(type_after.attrs, after) if type_after - else pos_.node(d).copy(after) + else pos_.node(d).copy(after), ) d -= 1 i -= 1 return self.step( - ReplaceStep(pos, pos, Slice(before.append(after), depth, depth), True) + ReplaceStep(pos, pos, Slice(before.append(after), depth, depth), True), ) def join(self, pos: int, depth: int = 1) -> "Transform": diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 8dd0c35..0b2b4d5 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -1,11 +1,10 @@ -from typing import Mapping, Sequence, Union - -from typing_extensions import TypeAlias +from collections.abc import Mapping, Sequence +from typing import TypeAlias JSONDict: TypeAlias = Mapping[str, "JSON"] JSONList: TypeAlias = Sequence["JSON"] -JSON: TypeAlias = Union[JSONDict, JSONList, str, int, float, bool, None] +JSON: TypeAlias = JSONDict | JSONList | str | int | float | bool | None Attrs: TypeAlias = JSONDict diff --git a/pyproject.toml b/pyproject.toml index 2944ef6..544a631 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,14 +7,14 @@ name = "prosemirror" version = "0.4.0" description = "Python implementation of core ProseMirror modules for collaborative editing" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" authors = [ { name = "Samuel Cormier-Iijima", email = "sam@fellow.co" }, { name = "Shen Li", email = "dustet@gmail.com" }, ] license = { text = "BSD-3-Clause" } keywords = ["prosemirror", "collaborative", "editing"] -dependencies = ["typing-extensions>=4.4", "lxml>=5.2", "cssselect>=1.2"] +dependencies = ["typing-extensions>=4.1", "lxml>=4.9", "cssselect>=1.2"] [project.optional-dependencies] dev = [ @@ -24,13 +24,34 @@ dev = [ "mypy~=1.9", "pytest~=8.1", "pytest-cov~=5.0", - "ruff~=0.3", + "ruff~=0.4", ] [tool.ruff.lint] -select = ["E", "F", "W", "I", "RUF"] +select = [ + "ANN", + "B", + "COM", + "E", + "EM", + "F", + "I", + "I", + "N", + "PT", + "RSE", + "RUF", + "SIM", + "UP", + "W", +] +ignore = ["COM812"] preview = true +[tool.ruff.lint.per-file-ignores] +"prosemirror/test_builder/**" = ["ANN"] +"tests/**" = ["ANN"] + [tool.ruff.format] preview = true diff --git a/tests/conftest.py b/tests/conftest.py index 032c924..33f7dd0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture +@pytest.fixture() def ist(): def ist(a, b=None, key=None): if key is None: diff --git a/tests/prosemirror_model/tests/test_content.py b/tests/prosemirror_model/tests/test_content.py index a5fa7c0..a7b5b09 100644 --- a/tests/prosemirror_model/tests/test_content.py +++ b/tests/prosemirror_model/tests/test_content.py @@ -30,7 +30,7 @@ def match(expr, types): @pytest.mark.parametrize( - "expr,types,valid", + ("expr", "types", "valid"), [ ("", "", True), ("", "image", False), @@ -100,7 +100,7 @@ def test_match_type(expr, types, valid): @pytest.mark.parametrize( - "expr,before,after,result", + ("expr", "before", "after", "result"), [ ( "paragraph horizontal_rule paragraph", @@ -218,7 +218,7 @@ def test_fill_before(expr, before, after, result): @pytest.mark.parametrize( - "expr,before,mid,after,left,right", + ("expr", "before", "mid", "after", "left", "right"), [ ( "paragraph horizontal_rule paragraph horizontal_rule paragraph", @@ -287,7 +287,7 @@ def test_fill3_before(expr, before, mid, after, left, right): b = False if a: b = content.match_fragment( - before.content.append(a).append(mid.content) + before.content.append(a).append(mid.content), ).fill_before(after.content, True) if left: left = Node.from_json(schema, left) diff --git a/tests/prosemirror_model/tests/test_diff.py b/tests/prosemirror_model/tests/test_diff.py index 7472b4a..ee92523 100644 --- a/tests/prosemirror_model/tests/test_diff.py +++ b/tests/prosemirror_model/tests/test_diff.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize( - "a,b", + ("a", "b"), [ ( doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), @@ -39,7 +39,7 @@ def test_find_diff_start(a, b): @pytest.mark.parametrize( - "a,b", + ("a", "b"), [ ( doc(p("a", em("b")), p("hello"), blockquote(h1("bye"))), diff --git a/tests/prosemirror_model/tests/test_dom.py b/tests/prosemirror_model/tests/test_dom.py index 2033248..c06b529 100644 --- a/tests/prosemirror_model/tests/test_dom.py +++ b/tests/prosemirror_model/tests/test_dom.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize( - "desc,doc,html", + ("desc", "doc", "html"), [ ( "it can represent simple node", @@ -57,7 +57,7 @@ p( "a ", a({"href": "foo"}, "big ", a({"href": "bar"}, "nested"), " link"), - ) + ), ), '
', @@ -65,7 +65,8 @@ ( "it can represent an unordered list", doc( - ul(li(p("one")), li(p("two")), li(p("three", strong("!")))), p("after") + ul(li(p("one")), li(p("two")), li(p("three", strong("!")))), + p("after"), ), "one
two
three" "!
after
", @@ -73,7 +74,8 @@ ( "it can represent an ordered list", doc( - ol(li(p("one")), li(p("two")), li(p("three", strong("!")))), p("after") + ol(li(p("one")), li(p("two")), li(p("three", strong("!")))), + p("after"), ), "one
two
three" "!
after
", @@ -119,7 +121,7 @@ def test_serializer_first(doc, html, desc): @pytest.mark.parametrize( - "desc,serializer,doc,expect", + ("desc", "serializer", "doc", "expect"), [ ( "it can omit a mark", @@ -161,7 +163,7 @@ def test_html_is_escaped(): @pytest.mark.parametrize( - "desc,doc,expect", + ("desc", "doc", "expect"), [ ( "Basic text node", @@ -169,7 +171,10 @@ def test_html_is_escaped(): { "type": "doc", "content": [ - {"type": "paragraph", "content": [{"type": "text", "text": "test"}]} + { + "type": "paragraph", + "content": [{"type": "text", "text": "test"}], + }, ], }, ), @@ -189,7 +194,7 @@ def test_html_is_escaped(): "text": "some bolded text", }, ], - } + }, ], }, ), @@ -247,7 +252,7 @@ def test_html_is_escaped(): "href": "www.google.ca", "title": None, }, - } + }, ], "text": "google", }, @@ -264,7 +269,7 @@ def test_html_is_escaped(): "alt": None, "title": None, }, - } + }, ], }, { @@ -274,7 +279,7 @@ def test_html_is_escaped(): "type": "text", "marks": [{"type": "strong", "attrs": {}}], "text": "Hello", - } + }, ], }, { @@ -314,9 +319,9 @@ def test_html_is_escaped(): { "type": "paragraph", "content": [ - {"type": "text", "text": "Testing the result of this"} + {"type": "text", "text": "Testing the result of this"}, ], - } + }, ], }, ), diff --git a/tests/prosemirror_model/tests/test_mark.py b/tests/prosemirror_model/tests/test_mark.py index 504986b..760a928 100644 --- a/tests/prosemirror_model/tests/test_mark.py +++ b/tests/prosemirror_model/tests/test_mark.py @@ -44,7 +44,7 @@ def link(href, title=None): @pytest.mark.parametrize( - "a,b,res", + ("a", "b", "res"), [ ([em_, strong], [em_, strong], True), ([em_, strong], [em_, code], False), @@ -58,7 +58,7 @@ def test_same_set(a, b, res): @pytest.mark.parametrize( - "a,b,res", + ("a", "b", "res"), [ (link("http://foo"), (link("http://foo")), True), (link("http://foo"), link("http://bar"), False), @@ -116,7 +116,7 @@ def test_remove_form_set(ist): Mark.same_set( link("http://foo", "title").remove_from_set([link("http://foo")]), [link("http://foo")], - ) + ), ) @@ -171,7 +171,7 @@ class TestResolvedPosMarks: ) @pytest.mark.parametrize( - "doc,mark,result", + ("doc", "mark", "result"), [ (doc(p(em("foo"))), em_, True), (doc(p(em("foo"))), strong, False), @@ -185,7 +185,7 @@ def test_is_at(self, doc, mark, result): assert mark.is_in_set(doc.resolve(doc.tag["a"]).marks()) is result @pytest.mark.parametrize( - "a,b", + ("a", "b"), [ (custom_doc.resolve(4).marks(), [custom_strong]), (custom_doc.resolve(3).marks(), [remark1, custom_strong]), diff --git a/tests/prosemirror_model/tests/test_node.py b/tests/prosemirror_model/tests/test_node.py index a1e4410..b0cc8b4 100644 --- a/tests/prosemirror_model/tests/test_node.py +++ b/tests/prosemirror_model/tests/test_node.py @@ -18,7 +18,8 @@ img = out["img"] custom_schema: Schema[ - Literal["doc", "paragraph", "text", "contact", "hard_break"], str + Literal["doc", "paragraph", "text", "contact", "hard_break"], + str, ] = Schema({ "nodes": { "doc": {"content": "paragraph+"}, @@ -72,13 +73,14 @@ def test_respected_by_fragment(self): ) assert str(f) == "