Skip to content

Commit

Permalink
refactor(library/init/lean/): store Syntax.node children in array
Browse files Browse the repository at this point in the history
  • Loading branch information
Kha committed Mar 28, 2019
1 parent 9a6785c commit f44a788
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 69 deletions.
53 changes: 51 additions & 2 deletions library/init/data/array/basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,15 @@ mkEmpty ()
instance : HasEmptyc (Array α) :=
⟨Array.empty⟩

instance : Inhabited (Array α) :=
⟨Array.empty⟩

def isEmpty (a : Array α) : Bool :=
a.size = 0

def singleton (v : α) : Array α :=
mkArray 1 v

@[extern cpp inline "lean::array_index(#2, #3)"]
def index (a : @& Array α) (i : @& Fin a.sz) : α :=
a.data i
Expand All @@ -65,6 +71,9 @@ a.index ⟨i.toNat, h⟩
def get [Inhabited α] (a : @& Array α) (i : @& Nat) : α :=
if h : i < a.sz then a.index ⟨i, h⟩ else default α

def oget (a : @& Array α) (i : @& Nat) : Option α :=
if h : i < a.size then some (a.index ⟨i, h⟩) else none

@[extern cpp inline "lean::array_update(#2, #3, #4)"]
def update (a : Array α) (i : @& Fin a.sz) (v : α) : Array α :=
{ sz := a.sz,
Expand Down Expand Up @@ -111,6 +120,25 @@ iterateAux a f 0 b
@[inline] def foldl (a : Array α) (f : α → β → β) (b : β) : β :=
iterate a b (λ _, f)

section
variables {m : Type v → Type v} [Monad m] [Inhabited β]
local attribute [instance] monadInhabited

-- TODO(Leo): justify termination using wf-rec
@[specialize] partial def miterateAux (a : Array α) (f : Π i : Fin a.sz, α → β → m β) : Nat → β → m β
| i b :=
if h : i < a.sz then
let idx : Fin a.sz := ⟨i, h⟩ in
f idx (a.index idx) b >>= miterateAux (i+1)
else pure b

@[inline] def miterate (a : Array α) (b : β) (f : Π i : Fin a.sz, α → β → m β) : m β :=
miterateAux a f 0 b

@[inline] def mfoldl (a : Array α) (b : β) (f : α → β → m β) : m β :=
miterate a b (λ _, f)
end

@[specialize] private def revIterateAux (a : Array α) (f : Π i : Fin a.sz, α → β → β) : Π (i : Nat), i ≤ a.sz → β → β
| 0 h b := b
| (j+1) h b :=
Expand Down Expand Up @@ -151,11 +179,32 @@ if h : a.size ≤ b.size
then foreach a (λ ⟨i, h'⟩, f (b.index ⟨i, Nat.ltOfLtOfLe h' h⟩))
else foreach b (λ ⟨i, h'⟩, f (a.index ⟨i, Nat.ltTrans h' (Nat.gtOfNotLe h)⟩))

section
variables {m : Type u → Type u} [Monad m]
local attribute [instance] monadInhabited

def mforeachAuxInh (a : Array α) : Inhabited { a' : Array α // a'.sz = a.sz } :=
⟨⟨a, rfl⟩⟩
local attribute [instance] mforeachAuxInh

@[inline] private def mforeachAux (a : Array α) (f : Π i : Fin a.sz, α → m α) : m { a' : Array α // a'.sz = a.sz } :=
miterate a ⟨a, rfl⟩ $ λ i v ⟨a', h⟩, do
let i' : Fin a'.sz := Eq.recOn h.symm i,
x ← f i v,
pure $ ⟨a'.update i' x, (szUpdateEq a' i' x) ▸ h⟩

@[inline] def mforeach (a : Array α) (f : Π i : Fin a.sz, α → m α) : m (Array α) :=
Subtype.val <$> mforeachAux a f

@[inline] def mmap (f : α → m α) (a : Array α) : m (Array α) :=
mforeach a (λ _, f)
end

end Array

def List.toArrayAux {α : Type u} : List α → Array α → Array α
@[inlineIfReduce] def List.toArrayAux {α : Type u} : List α → Array α → Array α
| [] r := r
| (a::as) r := List.toArrayAux as (r.push a)

def List.toArray {α : Type u} (l : List α) : Array α :=
@[inline] def List.toArray {α : Type u} (l : List α) : Array α :=
l.toArrayAux ∅
6 changes: 3 additions & 3 deletions library/init/lean/elaborator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ partial def toPexpr : Syntax → ElaboratorM Expr
let v := view stringLit stx,
pure $ Expr.lit $ Literal.strVal (v.value.getOrElse "NOTAString")
| @choice := do
last::rev ← List.reverse <$> args.mmap (λ a, toPexpr a)
last::rev ← List.reverse <$> args.toList.mmap (λ a, toPexpr a)
| error stx "ill-formed choice",
pure $ Expr.mdata (MData.empty.setNat `choice args.length) $
pure $ Expr.mdata (MData.empty.setNat `choice args.size) $
rev.reverse.foldr Expr.app last
| @structInst := do
let v := view structInst stx,
Expand Down Expand Up @@ -856,7 +856,7 @@ def setOption.elaborate : Elaborator :=
def noKind.elaborate : Elaborator := λ stx, do
some n ← pure stx.asNode
| error stx "noKind.elaborate: unreachable",
n.args.mmap' command.elaborate
n.args.toList.mmap' command.elaborate

def end.elaborate : Elaborator :=
λ cmd, do
Expand Down
4 changes: 2 additions & 2 deletions library/init/lean/expander.lean
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def mkNotationTransformer (nota : NotationMacro) : transformer :=
λ stx, do
some {args := stxArgs, ..} ← pure stx.asNode
| error stx "mkNotationTransformer: unreachable",
flip StateT.run' {NotationTransformerState . stx := stx, stxArgs := stxArgs} $ do
flip StateT.run' {NotationTransformerState . stx := stx, stxArgs := stxArgs.toList} $ do
let spec := nota.nota.spec,
-- Walk through the notation specification, consuming `stx` args and building up substitutions
-- for the notation RHS.
Expand Down Expand Up @@ -471,7 +471,7 @@ def Subtype.transform : transformer :=
def universes.transform : transformer :=
λ stx, do
let v := view «universes» stx,
pure $ Syntax.list $ v.ids.map (λ id, review «universe» {id := id})
pure $ Syntax.list $ List.toArray $ v.ids.map (λ id, review «universe» {id := id})

def sorry.transform : transformer :=
λ stx, pure $ mkApp (globId `sorryAx) [review hole {}, globId `Bool.false]
Expand Down
72 changes: 36 additions & 36 deletions library/init/lean/parser/combinators.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ local notation `Parser` := m Syntax
variables [Monad m] [MonadExcept (Parsec.Message Syntax) m] [MonadParsec Syntax m] [Alternative m]

def node (k : SyntaxNodeKind) (rs : List Parser) : Parser :=
do args ← rs.mfoldl (λ (args : List Syntax) r, do
do args ← rs.mfoldl (λ (args : Array Syntax) r, do
-- on error, append partial Syntax tree to previous successful parses and rethrow
a ← catch r $ λ msg, match args with
-- do not wrap an error in the first argument to uphold the invariant documented at `SyntaxNode`
| [] := throw msg
| _ :=
let args := msg.custom.get :: args in
throw {msg with custom := Syntax.mkNode k args.reverse},
pure (a::args)
) [],
pure $ Syntax.mkNode k args.reverse
a ← catch r $ λ msg, if args.isEmpty then
-- do not wrap an error in the first argument to uphold the invariant documented at `SyntaxNode`
throw msg
else
let args := args.push msg.custom.get in
throw {msg with custom := Syntax.mkNode k args},
pure $ args.push a
) Array.empty,
pure $ Syntax.mkNode k args

@[reducible] def seq : List Parser → Parser := node noKind

Expand All @@ -42,28 +42,29 @@ instance node.view (k) (rs : List Parser) [i : HasView α k] : Parser.HasView α

-- Each Parser Combinator comes equipped with `HasView` and `HasTokens` instances

private def many1Aux (p : Parser) : List Syntax → Nat → Parser
private def many1Aux (p : Parser) : Array Syntax → Nat → Parser
| as 0 := error "unreachable"
| as (n+1) := do
a ← catch p (λ msg, throw {msg with custom :=
-- append `Syntax.missing` to make clear that List is incomplete
Syntax.list (Syntax.missing::msg.custom.get::as).reverse}),
many1Aux (a::as) n <|> pure (Syntax.list (a::as).reverse)
Syntax.list $ (as.push msg.custom.get).push Syntax.missing}),
many1Aux (as.push a) n <|> pure (Syntax.list $ as.push a)

def many1 (r : Parser) : Parser :=
do rem ← remaining, many1Aux r [] (rem+1)
do rem ← remaining, many1Aux r Array.empty (rem+1)

instance many1.tokens (r : Parser) [Parser.HasTokens r] : Parser.HasTokens (many1 r) :=
⟨tokens r⟩

--TODO(Sebastian): should this be an `Array` as well?
instance many1.view (r : Parser) [Parser.HasView α r] : Parser.HasView (List α) (many1 r) :=
{ view := λ stx, match stx.asNode with
| some n := n.args.map (HasView.view r)
| some n := n.args.toList.map (HasView.view r)
| _ := [HasView.view r Syntax.missing],
review := λ as, Syntax.list $ as.map (review r) }
review := λ as, Syntax.list $ List.toArray $ as.map (review r) }

def many (r : Parser) : Parser :=
many1 r <|> pure (Syntax.list [])
many1 r <|> pure (Syntax.list Array.empty)

instance many.tokens (r : Parser) [Parser.HasTokens r] : Parser.HasTokens (many r) :=
⟨tokens r⟩
Expand All @@ -72,25 +73,25 @@ instance many.view (r : Parser) [HasView α r] : Parser.HasView (List α) (many
/- Remark: `many1.view` can handle empty list. -/
{..many1.view r}

private def sepByAux (p : m Syntax) (sep : Parser) (allowTrailingSep : Bool) : Bool → List Syntax → Nat → Parser
private def sepByAux (p : m Syntax) (sep : Parser) (allowTrailingSep : Bool) : Bool → Array Syntax → Nat → Parser
| pOpt as 0 := error "unreachable"
| pOpt as (n+1) := do
let p := if pOpt then some <$> p <|> pure none else some <$> p,
some a ← catch p (λ msg, throw {msg with custom :=
-- append `Syntax.missing` to make clear that List is incomplete
Syntax.list (Syntax.missing::msg.custom.get::as).reverse})
| pure (Syntax.list as.reverse),
Syntax.list $ (as.push msg.custom.get).push Syntax.missing})
| pure (Syntax.list as),
-- I don't want to think about what the output on a failed separator parse should look like
let sep := try sep,
some s ← some <$> sep <|> pure none
| pure (Syntax.list (a::as).reverse),
sepByAux allowTrailingSep (s::a::as) n
| pure (Syntax.list (as.push a)),
sepByAux allowTrailingSep ((as.push s).push a) n

def sepBy (p sep : Parser) (allowTrailingSep := true) : Parser :=
do rem ← remaining, sepByAux p sep allowTrailingSep true [] (rem+1)
do rem ← remaining, sepByAux p sep allowTrailingSep true Array.empty (rem+1)

def sepBy1 (p sep : Parser) (allowTrailingSep := true) : Parser :=
do rem ← remaining, sepByAux p sep allowTrailingSep false [] (rem+1)
do rem ← remaining, sepByAux p sep allowTrailingSep false Array.empty (rem+1)

instance sepBy.tokens (p sep : Parser) (a) [Parser.HasTokens p] [Parser.HasTokens sep] :
Parser.HasTokens (sepBy p sep a) :=
Expand All @@ -112,9 +113,9 @@ private def sepBy.viewAux {α β} (p sep : Parser) [Parser.HasView α p] [Parser
instance sepBy.view {α β} (p sep : Parser) (a) [Parser.HasView α p] [Parser.HasView β sep] :
Parser.HasView (List (SepBy.Elem.View α β)) (sepBy p sep a) :=
{ view := λ stx, match stx.asNode with
| some n := sepBy.viewAux p sep n.args
| some n := sepBy.viewAux p sep n.args.toList
| _ := [⟨view p Syntax.missing, none⟩],
review := λ as, Syntax.list $ as.bind (λ a, match a with
review := λ as, Syntax.list $ List.toArray $ as.bind (λ a, match a with
| ⟨v, some vsep⟩ := [review p v, review sep vsep]
| ⟨v, none⟩ := [review p v]) }

Expand All @@ -132,21 +133,20 @@ def optional (r : Parser) (require := false) : Parser :=
if require then r else
do r ← optional $
-- on error, wrap in "some"
catch r (λ msg, throw {msg with custom := Syntax.list [msg.custom.get]}),
catch r (λ msg, throw {msg with custom := Syntax.list $ Array.singleton msg.custom.get}),
pure $ match r with
| some r := Syntax.list [r]
| none := Syntax.list []
| some r := Syntax.list $ Array.singleton r
| none := Syntax.list $ Array.empty

instance optional.tokens (r : Parser) [Parser.HasTokens r] (req) : Parser.HasTokens (optional r req) :=
⟨tokens r⟩
instance optional.view (r : Parser) [Parser.HasView α r] (req) : Parser.HasView (Option α) (optional r req) :=
{ view := λ stx, match stx.asNode with
| some {args := [], ..} := none
| some {args := [stx], ..} := some $ HasView.view r stx
| some n := HasView.view r <$> n.args.oget 0
| _ := some $ view r Syntax.missing,
review := λ a, match a with
| some a := Syntax.list [review r a]
| none := Syntax.list [] }
| some a := Syntax.list $ Array.singleton $ review r a
| none := Syntax.list $ Array.empty }
instance optional.viewDefault (r : Parser) [Parser.HasView α r] (req) : Parser.HasViewDefault (optional r req) (Option α) none := ⟨⟩

/-- Parse a List `[p1, ..., pn]` of parsers as `p1 <|> ... <|> pn`.
Expand All @@ -169,7 +169,7 @@ def longestMatch (rs : List Parser) : Parser :=
do stxs ← MonadParsec.longestMatch rs,
match stxs with
| [stx] := pure stx
| _ := pure $ Syntax.mkNode choice stxs
| _ := pure $ Syntax.mkNode choice stxs.toArray

instance longestMatch.tokens (rs : List Parser) [Parser.HasTokens rs] : Parser.HasTokens (longestMatch rs) :=
⟨tokens rs⟩
Expand All @@ -179,7 +179,7 @@ def choiceAux : List Parser → Nat → Parser
| [] _ := error "choice: Empty List"
| (r::rs) i :=
do { stx ← r,
pure $ Syntax.mkNode ⟨Name.mkNumeral Name.anonymous i⟩ [stx] }
pure $ Syntax.mkNode ⟨Name.mkNumeral Name.anonymous i⟩ $ Array.singleton stx }
<|> choiceAux rs (i+1)

/-- Parse a List `[p1, ..., pn]` of parsers as `p1 <|> ... <|> pn`.
Expand All @@ -198,7 +198,7 @@ instance choice.tokens (rs : List Parser) [Parser.HasTokens rs] : Parser.HasToke
def longestChoice (rs : List Parser) : Parser :=
do stx::stxs ← MonadParsec.longestMatch $ rs.enum.map $ λ ⟨i, r⟩, do {
stx ← r,
pure $ Syntax.mkNode ⟨Name.mkNumeral Name.anonymous i⟩ [stx]
pure $ Syntax.mkNode ⟨Name.mkNumeral Name.anonymous i⟩ $ Array.singleton stx
} | error "unreachable",
pure stx

Expand Down
2 changes: 1 addition & 1 deletion library/init/lean/parser/module.lean
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def eoi.Parser : moduleParser := do
it ← leftOver,
-- add `eoi` Node for left-over input
let stop := it.toEnd,
pure $ Syntax.mkNode eoi [Syntax.atom ⟨some ⟨⟨stop, stop⟩, stop.offset, ⟨stop, stop⟩⟩, ""⟩]
pure $ Syntax.mkNode eoi [Syntax.atom ⟨some ⟨⟨stop, stop⟩, stop.offset, ⟨stop, stop⟩⟩, ""⟩].toArray

/-- Read command, recovering from errors inside commands (attach partial Syntax tree)
as well as unknown commands (skip input). -/
Expand Down
24 changes: 15 additions & 9 deletions library/init/lean/parser/syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Sebastian Ullrich
-/
prelude
import init.lean.name init.lean.parser.parsec
import init.lean.name init.lean.parser.parsec init.data.array

namespace Lean
namespace Parser
Expand Down Expand Up @@ -53,7 +53,7 @@ Remark: We do create `SyntaxNode`'s with an Empty `args` field (e.g. for represe
-/
structure SyntaxNode (Syntax : Type) :=
(kind : SyntaxNodeKind)
(args : List Syntax)
(args : Array Syntax)
-- Lazily propagated scopes. Scopes are pushed inwards when a Node is destructed via `Syntax.asNode`,
-- until an ident or an atom (in which the scopes vanish) is reached.
-- Scopes are stored in a stack with the most recent Scope at the top.
Expand Down Expand Up @@ -106,15 +106,21 @@ def flipScopes (scopes : macroScopes) : Syntax → Syntax
| (Syntax.rawNode n) := Syntax.rawNode {n with scopes := n.scopes.flip scopes}
| stx := stx

def mkNode (kind : SyntaxNodeKind) (args : List Syntax) :=
def mkNode (kind : SyntaxNodeKind) (args : Array Syntax) :=
Syntax.rawNode { kind := kind, args := args }

/-- Match against `Syntax.rawNode`, propagating lazy macro scopes. -/
def asNode : Syntax → Option (SyntaxNode Syntax)
| (Syntax.rawNode n) := some {n with args := n.args.map (flipScopes n.scopes), scopes := []}
| _ := none

protected def list (args : List Syntax) :=
-- helper function used by the `node!` macro, to make sure its `view` function is branch-less
@[noinline] def args (stx : Syntax) : Array Syntax :=
match stx.asNode with
| some n := n.args
| _ := Array.empty

protected def list (args : Array Syntax) :=
mkNode noKind args

def kind : Syntax → Option SyntaxNodeKind
Expand Down Expand Up @@ -173,7 +179,7 @@ def updateLeading (source : String) : Syntax → Syntax :=
partial def getHeadInfo : Syntax → Option SourceInfo
| (atom a) := a.info
| (ident id) := id.info
| (rawNode n) := n.args.foldr (λ s r, getHeadInfo s <|> r) none
| (rawNode n) := n.args.foldl (λ s r, getHeadInfo s <|> r) none
| _ := none

def getPos (stx : Syntax) : Option Parsec.Position :=
Expand All @@ -190,7 +196,7 @@ partial def reprint : Syntax → Option String
| (ident id@{info := some info, ..}) := pure $ info.leading.toString ++ id.rawVal.toString ++ info.trailing.toString
| (ident id@{info := none, ..}) := pure id.rawVal.toString
| (rawNode n) :=
if n.kind.name = choice.name then match n.args with
if n.kind.name = choice.name then match n.args.toList with
-- should never happen
| [] := failure
-- check that every choice prints the same
Expand All @@ -199,7 +205,7 @@ partial def reprint : Syntax → Option String
ss ← ns.mmap reprint,
guard $ ss.all (= s),
pure s
else String.join <$> n.args.mmap reprint
else String.join <$> n.args.toList.mmap reprint
| missing := ""

protected partial def toFormat : Syntax → Format
Expand All @@ -210,9 +216,9 @@ protected partial def toFormat : Syntax → Format
toFmt "`" ++ toFmt id.val ++ scopes
| stx@(rawNode n) :=
let scopes := match n.scopes with [] := toFmt "" | _ := bracket "{" (joinSep n.scopes.reverse ", ") "}" in
if n.kind.name = `Lean.Parser.noKind then sbracket $ scopes ++ joinSep (n.args.map toFormat) line
if n.kind.name = `Lean.Parser.noKind then sbracket $ scopes ++ joinSep (n.args.toList.map toFormat) line
else let shorterName := n.kind.name.replacePrefix `Lean.Parser Name.anonymous
in paren $ joinSep ((toFmt shorterName ++ scopes) :: n.args.map toFormat) line
in paren $ joinSep ((toFmt shorterName ++ scopes) :: n.args.toList.map toFormat) line
| missing := "<missing>"

instance : HasToFormat Syntax := ⟨Syntax.toFormat⟩
Expand Down
2 changes: 1 addition & 1 deletion library/init/lean/parser/term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def app.Parser : trailingTermParser :=
node! app [fn: getLeading, Arg: Term.Parser maxPrec]

def mkApp (fn : Syntax) (args : List Syntax) : Syntax :=
args.foldl (λ fn Arg, Syntax.mkNode app [fn, Arg]) fn
args.foldl (λ fn Arg, Syntax.mkNode app [fn, Arg].toArray) fn

@[derive Parser.HasTokens Parser.HasView]
def arrow.Parser : trailingTermParser :=
Expand Down
Loading

0 comments on commit f44a788

Please sign in to comment.