From 39ed69a8a9a82abd878ebfa52cb1b0e347eb7195 Mon Sep 17 00:00:00 2001 From: hlaaftana Date: Mon, 7 Feb 2022 14:30:11 +0300 Subject: [PATCH] initial --- .gitignore | 3 + README.md | 3 + src/unions.nim | 2 + src/unions/conditional.nim | 276 +++++++++++++++++++++++++++++++++++ src/unions/flatbranch.nim | 207 ++++++++++++++++++++++++++ src/unions/private/utils.nim | 65 +++++++++ src/unions/unionfield.nim | 104 +++++++++++++ tests/config.nims | 1 + tests/test_conditional.nim | 33 +++++ tests/test_flatbranch.nim | 30 ++++ tests/test_unionfield.nim | 51 +++++++ unions.nimble | 31 ++++ 12 files changed, 806 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 src/unions.nim create mode 100644 src/unions/conditional.nim create mode 100644 src/unions/flatbranch.nim create mode 100644 src/unions/private/utils.nim create mode 100644 src/unions/unionfield.nim create mode 100644 tests/config.nims create mode 100644 tests/test_conditional.nim create mode 100644 tests/test_flatbranch.nim create mode 100644 tests/test_unionfield.nim create mode 100644 unions.nimble diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4ec9bcc --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.exe +*.dll +docs/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..ca2ff8e --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# unions + +Object variants utilities. Due for a cleanup and rename. See tests for current modules diff --git a/src/unions.nim b/src/unions.nim new file mode 100644 index 0000000..a3d824d --- /dev/null +++ b/src/unions.nim @@ -0,0 +1,2 @@ +import unions/[conditional, flatbranch] +export conditional, flatbranch diff --git a/src/unions/conditional.nim b/src/unions/conditional.nim new file mode 100644 index 0000000..c748262 --- /dev/null +++ b/src/unions/conditional.nim @@ -0,0 +1,276 @@ +## object variants generalized to any condition for each possible union value + +runnableExamples: + conditionalUnion: + type Foo = ref object + num: int + case branch: _ # type has to be _, all types without a case branch with _ as discriminator type are kept in the type section + # the name "branch" can be changed or made _ in which case it defaults to "branch" for now + of Odd, num mod 2 == 1: # branch names can also be _ + name: string + of DoubleEven, num mod 4 == 0: + a, b: int + of Even: + a: int # duplicate names are allowed, only the accessors care that they have the same type + + var foo = Foo(num: 1) + foo.name = "abc" + doAssert foo.name == "abc" + doAssert foo.branch == Odd # branch is not an actual variable + foo.num = 2 + foo.resetBranch() # advantage over object variants, named after "branch" + foo.a = 3 + doAssert foo.a == 3 + doAssert foo.branch == Even + foo.num = 16 + doAssert foo.a == 3 + foo.b = 4 + doAssert foo.b == 4 + doAssert foo.branch == DoubleEven + +import macros, strutils, sets, private/utils + +type ConditionalFieldDefect* = object of FieldDefect + +proc patchTypeSection(typesec: NimNode, poststmts: var seq[NimNode]) = + expectKind typesec, nnkTypeSection + var typedefIndex = 0 + template insertType(n: NimNode) = + typesec.insert(typedefIndex, n) + inc typedefIndex + while typedefIndex < typesec.len: + var objectNode = typesec[typedefIndex][2] + while objectNode.kind in {nnkRefTy, nnkPtrTy}: objectNode = objectNode[0] + if objectNode.kind == nnkObjectTy: + let typeName = typesec[typedefIndex][0].realBasename + var fieldNames: seq[string] # use for template + for rec in objectNode[2]: + if rec.kind == nnkIdentDefs: + for i in 0 .. rec.len - 3: fieldNames.add(rec[i].realBasename) + for recI, rec in objectNode[2]: + if rec.kind == nnkRecCase and rec[0][1].eqIdent("_"): + type Branch = ref object + index: int + name: string + cond, fields: NimNode + var + branches: seq[Branch] + defaultBranch: Branch + allBranches: seq[Branch] + branchName: NimNode + branchNameExported: bool + if rec[0][0].kind == nnkPostfix: + branchNameExported = true + rec[0][0] = rec[0][0][1] + branchName = ident(rec[0][0].realBasename) + template exportIfBranchExported(n: NimNode): NimNode = + if branchNameExported: postfix(n, "*") else: n + let used = newTree(nnkPragma, ident"used") + for i in 1 ..< rec.len: + let b = rec[i] + case b.kind + of nnkOfBranch: + if b.len > 3: error("`of` for branches only accepts a name and an optional condition", b) + let name = b[0] + if name.kind notin {nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice, nnkAccQuoted}: + error("first argument of `of` must be branch name", name) + let cond = if b.len == 3: b[1] else: nil + if not cond.isNil: + for f in fieldNames: + cond.replaceIdent(f, newDotExpr(ident"self", ident(f))) + let recList = b[^1] + let branch = Branch(name: if name.eqIdent("_"): "" else: $name, cond: cond, fields: recList) + if cond.isNil: + if not defaultBranch.isNil: + error("cannot set multiple default branches" & + (if defaultBranch.name.len != 0: ", original default branch is " & defaultBranch.name + else: ""), b) + defaultBranch = branch + else: + branches.add(branch) + allBranches.add(branch) + of nnkElse: + let branch = Branch(name: "", cond: nil, fields: b[1]) + branches.add(branch) + allBranches.add(branch) + of nnkElifBranch: # not actually possible + let branch = Branch(name: "", cond: b[0], fields: b[1]) + branches.add(branch) + allBranches.add(branch) + else: error("unexpected reccase branch kind " & $b.kind, b) + let kinds = newTree(nnkEnumTy, newEmptyNode()) + for i, a in allBranches: + a.index = i + if a.name.len != 0: + kinds.add(newTree(nnkEnumFieldDef, ident(a.name), newLit(i))) + let realBranchName = if branchName.isNil: "branch" else: $branchName + let capitalizedBranch = capitalizeAscii(realBranchName) + let baseTypeName = typeName & capitalizedBranch + let enumName = ident(baseTypeName & "Kind") + insertType(newTree(nnkTypeDef, enumName, newEmptyNode(), kinds)) + template objtype(b: Branch): NimNode = + ident(baseTypeName & + (if b.name.len == 0: $b.index else: b.name.capitalizeAscii) & "Obj") + template fieldname(b: Branch): NimNode = + ident((if b.name.len == 0: "branch" & $b.index else: b.name.uncapitalizeAscii) & "Obj") + let unionRecList = newNimNode(nnkRecList) + for b in allBranches: + let objt = b.objtype + insertType(newTree(nnkTypeDef, objt, newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), b.fields))) + unionRecList.add(newTree(nnkIdentDefs, b.fieldname, objt, newEmptyNode())) + let unionTypeName = ident(baseTypeName & "Obj") + insertType(newTree(nnkTypeDef, + newTree(nnkPragmaExpr, unionTypeName, newTree(nnkPragma, ident"union")), + newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), unionRecList))) + when defined(js) and false: + insertType(newTree(nnkTypeDef, + newTree(nnkPragmaExpr, unionTypeName, newEmptyNode()), + newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), newTree(nnkRecList)))) + for u in unionRecList: + let name = u[0] + let ty = u[1] + poststmts.add(newProc( + name = name, + params = [ident"untyped", newIdentDefs(ident"self", unionTypeName)], + procType = nnkTemplateDef, + body = newTree(nnkCast, newTree(nnkRefTy, ty), ident"self"))) + let unionFieldName = ident(uncapitalizeAscii(realBranchName) & "Obj") + objectNode[2][recI] = newTree(nnkIdentDefs, unionFieldName, unionTypeName, newEmptyNode()) + when false: + let fieldTemplates = newStmtList() + for f in fieldNames: + fieldTemplates.add(newProc( + name = ident(f), + params = [ident"untyped"], + body = newDotExpr(ident"self", ident(f)), + procType = nnkTemplateDef, + pragmas = used)) + template nameOrEnumConvIndex(b: Branch): NimNode = + if b.name.len == 0: newCall(enumName, newLit(b.index)) else: ident(b.name) + block branchkind: + let ifstmt = newNimNode(nnkIfStmt) + for b in branches: + ifstmt.add(newTree(nnkElifBranch, b.cond.copy, + b.nameOrEnumConvIndex)) + if not defaultBranch.isNil: + ifstmt.add(newTree(nnkElse, defaultBranch.nameOrEnumConvIndex)) + else: + ifstmt.add(newTree(nnkElse, newCall(enumName, newLit(0)))) + poststmts.add( + newProc( + name = ident(realBranchName).exportIfBranchExported, + params = [enumName, newIdentDefs(ident"self", ident(typeName))], + body = ifstmt, + pragmas = used + ) + ) + poststmts.add( + newProc( + name = ident("reset" & capitalizedBranch).exportIfBranchExported, + params = [newEmptyNode(), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName)))], + body = newCall("reset", newDotExpr(ident"self", unionFieldName)), + pragmas = used)) + proc hasField(reclist: NimNode, field: string): bool = + for r in reclist: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + if r[i].realBasename.eqIdent(field): return true + false + var doneFields: HashSet[string] + for bi in 0 ..< branches.len: + let bran = branches[bi] + for r in bran.fields: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + let fieldName = r[i].realBasename + if fieldName notin doneFields: + var otherBranches: seq[Branch] + for bj in bi + 1 ..< branches.len: + if branches[bj].fields.hasField(fieldName): + otherBranches.add(branches[bj]) + template getField(b: Branch): NimNode = + newDotExpr(newDotExpr(newDotExpr(ident"self", unionFieldName), b.fieldname), ident(fieldName)) + let setterValue = genSym(nskParam, "value") + var ifstmt = newTree(nnkIfStmt) + var setter = newTree(nnkIfStmt) + var names: seq[string] + names.add(if bran.name.len == 0: $bran.index else: bran.name) + ifstmt.add(newTree(nnkElifBranch, bran.cond.copy, bran.getField)) + setter.add(newTree(nnkElifBranch, bran.cond.copy, bran.getField.newAssignment(setterValue))) + for ob in otherBranches: + ifstmt.add(newTree(nnkElifBranch, ob.cond.copy, ob.getField)) + setter.add(newTree(nnkElifBranch, ob.cond.copy, ob.getField.newAssignment(setterValue))) + names.add(if ob.name.len == 0: $ob.index else: ob.name) + if not defaultBranch.isNil and defaultBranch.fields.hasField(fieldName): + ifstmt.add(newTree(nnkElse, defaultBranch.getField)) + setter.add(newTree(nnkElse, defaultBranch.getField.newAssignment(setterValue))) + elif otherBranches.len == 0: # unsafe + ifstmt = ifstmt[0][1] + setter = setter[0][1] + else: + let raiser = newTree(nnkElse, + newTree(nnkRaiseStmt, newCall("newException", ident"ConditionalFieldDefect", + newLit("object is not of branch " & names.join(" or ") & " and therefore does not have field `" & fieldName & "`")))) + ifstmt.add(raiser) + setter.add(raiser) + let gettername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(ident(fieldName), "*") else: ident(fieldName) + poststmts.add(newProc( + name = gettername, + params = [r[^2], newIdentDefs(ident"self", ident(typeName))], + body = ifstmt, + procType = nnkProcDef, + pragmas = used + )) + let settername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(newTree(nnkAccQuoted, ident(fieldName), ident"="), "*") else: newTree(nnkAccQuoted, ident(fieldName), ident"=") + poststmts.add(newProc( + name = settername, + params = [newEmptyNode(), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName))), newIdentDefs(setterValue, r[^2])], + body = setter, + procType = nnkProcDef, + pragmas = used + )) + when false: + poststmts.add(newProc( + name = gettername, + params = [newTree(nnkVarTy, r[^2]), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName)))], + body = ifstmt, + pragmas = used + )) + doneFields.incl(fieldName) + if not defaultBranch.isNil: + for r in defaultBranch.fields: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + let fieldName = r[i].realBasename + if fieldName notin doneFields: + let gettername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(ident(fieldName), "*") else: ident(fieldName) + # somewhat unsafe + let body = newDotExpr(newDotExpr(newDotExpr(ident"self", unionFieldName), defaultBranch.fieldname), ident(fieldName)) + poststmts.add(newProc( + name = gettername, + params = [r[^2], newIdentDefs(ident"self", ident(typeName))], + body = body, + procType = nnkProcDef, + pragmas = used + )) + let setterValue = genSym(nskParam, "value") + let setter = body.newAssignment(setterValue) + let settername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(newTree(nnkAccQuoted, ident(fieldName), ident"="), "*") else: newTree(nnkAccQuoted, ident(fieldName), ident"=") + poststmts.add(newProc( + name = settername, + params = [newEmptyNode(), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName))), newIdentDefs(setterValue, r[^2])], + body = setter, + procType = nnkProcDef, + pragmas = used + )) + inc typedefIndex + +macro conditionalUnion*(body) = + result = applyTypeMacro(body, patchTypeSection) diff --git a/src/unions/flatbranch.nim b/src/unions/flatbranch.nim new file mode 100644 index 0000000..09d376c --- /dev/null +++ b/src/unions/flatbranch.nim @@ -0,0 +1,207 @@ +## flatten object variant case branches into their own objects, allows shared field names +## + +runnableExamples: + type FooKind = enum fkA, fkB, fkC + flattenBranches: + type + Foo = ref object + num: int + case kind: FooKind + of fkA: + name: string + of fkB: + a, b: int + else: + a: int + var f = Foo(num: 1, kind: fkA) + f.name = "abc" + doAssert f.kind == fkA + doAssert f.name == "abc" + f = Foo(num: 16, kind: fkB) + doAssert f.kind == fkB + doAssert f.a == 0 + doAssert f.b == 0 + f.a = 5 + doAssert f.a == 5 + f.b = 7 + doAssert f.b == 7 + f = Foo(num: 2, kind: fkC) + doAssert f.kind == fkC + doAssert f.a == 0 + +import macros, strutils, sets, private/utils + +proc patchTypeSection(typesec: NimNode, poststmts: var seq[NimNode]) = + expectKind typesec, nnkTypeSection + var typedefIndex = 0 + template insertType(n: NimNode) = + typesec.insert(typedefIndex, n) + inc typedefIndex + while typedefIndex < typesec.len: + var objectNode = typesec[typedefIndex][2] + while objectNode.kind in {nnkRefTy, nnkPtrTy}: objectNode = objectNode[0] + if objectNode.kind == nnkObjectTy: + let typeName = typesec[typedefIndex][0].realBasename + var fieldNames: seq[string] # use for template + for rec in objectNode[2]: + if rec.kind == nnkIdentDefs: + for i in 0 .. rec.len - 3: fieldNames.add(rec[i].realBasename) + for recI, rec in objectNode[2]: + if rec.kind == nnkRecCase: + type Branch = ref object + conds: seq[NimNode] + fields: NimNode + name: string + var + branches: seq[Branch] + defaultBranch: Branch + allBranches: seq[Branch] + branchName = rec[0][0].realBasename + template objtype(b: Branch): NimNode = + ident(typeName & branchName & b.name & "Obj") + template fieldname(b: Branch): NimNode = + ident(branchName & b.name & "Obj") + let used = newTree(nnkPragma, ident"used") + for i in 1 ..< rec.len: + let b = rec[i] + let recList = b[^1] + case b.kind + of nnkOfBranch: + var name: string + for c in b[0..^2]: + name.add(c.repr.capitalizeAscii) + let branch = Branch(conds: b[0..^2], + fields: copy recList, + name: name) + let objt = branch.objtype + b[^1] = newTree(nnkRecList, newTree(nnkIdentDefs, branch.fieldname, objt, newEmptyNode())) + insertType(newTree(nnkTypeDef, objt, newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), branch.fields))) + branches.add(branch) + allBranches.add(branch) + of nnkElse: + let branch = Branch(conds: @[], + fields: copy recList, + name: "Else") + let objt = branch.objtype + b[^1] = newTree(nnkRecList, newTree(nnkIdentDefs, branch.fieldname, objt, newEmptyNode())) + insertType(newTree(nnkTypeDef, objt, newEmptyNode(), + newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), branch.fields))) + defaultBranch = branch + allBranches.add(branch) + else: error("unexpected reccase branch kind " & $b.kind, b) + #let baseTypeName = typeName & capitalizedBranch + #let unionRecList = newNimNode(nnkRecList) + #let unionTypeName = ident(baseTypeName & "Obj") + #insertType(newTree(nnkTypeDef, + # newTree(nnkPragmaExpr, unionTypeName, newTree(nnkPragma, ident"union")), + # newEmptyNode(), + # newTree(nnkObjectTy, newEmptyNode(), newEmptyNode(), unionRecList))) + #let unionFieldName = ident(uncapitalizeAscii(realBranchName) & "Obj") + #objectNode[2][recI] = newTree(nnkIdentDefs, unionFieldName, unionTypeName, newEmptyNode()) + #template nameOrEnumConvIndex(b: Branch): NimNode = + # if b.name.len == 0: newCall(enumName, newLit(b.index)) else: ident(b.name) + proc hasField(reclist: NimNode, field: string): bool = + for r in reclist: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + if r[i].realBasename.eqIdent(field): return true + false + var doneFields: HashSet[string] + for bi in 0 ..< branches.len: + let bran = branches[bi] + for r in bran.fields: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + let fieldName = r[i].realBasename + if fieldName notin doneFields: + var otherBranches: seq[Branch] + for bj in bi + 1 ..< branches.len: + if branches[bj].fields.hasField(fieldName): + otherBranches.add(branches[bj]) + template getField(b: Branch): NimNode = + newDotExpr(newDotExpr(ident"self", b.fieldname), ident(fieldName)) + let setterValue = genSym(nskParam, "value") + var getter = newTree(nnkCaseStmt, newDotExpr(ident"self", ident(branchName))) + var setter = newTree(nnkCaseStmt, newDotExpr(ident"self", ident(branchName))) + var names: seq[string] + for b in bran.conds: + names.add(b.repr) + getter.add(newTree(nnkOfBranch).add(bran.conds).add(bran.getField)) + setter.add(newTree(nnkOfBranch).add(bran.conds).add(bran.getField.newAssignment(setterValue))) + for ob in otherBranches: + getter.add(newTree(nnkOfBranch).add(ob.conds).add(ob.getField)) + setter.add(newTree(nnkOfBranch).add(ob.conds).add(ob.getField.newAssignment(setterValue))) + for b in ob.conds: + names.add(b.repr) + if not defaultBranch.isNil and defaultBranch.fields.hasField(fieldName): + getter.add(newTree(nnkElse, defaultBranch.getField)) + setter.add(newTree(nnkElse, defaultBranch.getField.newAssignment(setterValue))) + elif otherBranches.len == 0: # unsafe + getter = getter[1][1] + setter = setter[1][1] + else: + let raiser = newTree(nnkElse, + newTree(nnkRaiseStmt, newCall("newException", ident"FieldDefect", + newLit("object is not of branch " & names.join(" or ") & " and therefore does not have field `" & fieldName & "`")))) + getter.add(raiser) + setter.add(raiser) + let gettername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(ident(fieldName), "*") else: ident(fieldName) + poststmts.add(newProc( + name = gettername, + params = [r[^2], newIdentDefs(ident"self", ident(typeName))], + body = getter, + procType = nnkProcDef, + pragmas = used + )) + let settername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(newTree(nnkAccQuoted, ident(fieldName), ident"="), "*") else: newTree(nnkAccQuoted, ident(fieldName), ident"=") + poststmts.add(newProc( + name = settername, + params = [newEmptyNode(), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName))), newIdentDefs(setterValue, r[^2])], + body = setter, + procType = nnkProcDef, + pragmas = used + )) + when false: + poststmts.add(newProc( + name = gettername, + params = [newTree(nnkVarTy, r[^2]), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName)))], + body = ifstmt, + pragmas = used + )) + doneFields.incl(fieldName) + if not defaultBranch.isNil: + for r in defaultBranch.fields: + if r.kind == nnkIdentDefs: + for i in 0 .. r.len - 3: + let fieldName = r[i].realBasename + if fieldName notin doneFields: + let gettername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(ident(fieldName), "*") else: ident(fieldName) + # somewhat unsafe + let body = newDotExpr(newDotExpr(ident"self", defaultBranch.fieldname), ident(fieldName)) + poststmts.add(newProc( + name = gettername, + params = [r[^2], newIdentDefs(ident"self", ident(typeName))], + body = body, + procType = nnkProcDef, + pragmas = used + )) + let setterValue = genSym(nskParam, "value") + let setter = body.newAssignment(setterValue) + let settername = if (r[i].kind == nnkPragmaExpr and r[i][0].kind == nnkPostfix) or r[i].kind == nnkPostfix: + postfix(newTree(nnkAccQuoted, ident(fieldName), ident"="), "*") else: newTree(nnkAccQuoted, ident(fieldName), ident"=") + poststmts.add(newProc( + name = settername, + params = [newEmptyNode(), newIdentDefs(ident"self", newTree(nnkVarTy, ident(typeName))), newIdentDefs(setterValue, r[^2])], + body = setter, + procType = nnkProcDef, + pragmas = used + )) + inc typedefIndex + +macro flattenBranches*(body) = + result = applyTypeMacro(body, patchTypeSection) diff --git a/src/unions/private/utils.nim b/src/unions/private/utils.nim new file mode 100644 index 0000000..28a7409 --- /dev/null +++ b/src/unions/private/utils.nim @@ -0,0 +1,65 @@ +import macros, strutils + +proc uncapitalizeAscii*(s: string): string = + result = s + result[0] = result[0].toLowerAscii + +proc realBasename*(n: NimNode): string = + $(if n.kind in {nnkPostfix, nnkPragmaExpr}: n.basename else: n) + +proc replaceIdent*(n: NimNode, name: string, to: NimNode) = + if n.kind != nnkAccQuoted: + for i in 1 ..< n.len: + if n[i].eqIdent(name) and not (i == 0 and n.kind in nnkCallKinds): n[i] = to + else: replaceIdent(n[i], name, to) + +proc withoutPragma*(prag: NimNode, name: string): (bool, NimNode) = + result[1] = newNimNode(nnkPragma, prag) + for i in 0 ..< prag.len: + if not prag[i].eqIdent(name): + result[1].add(prag[i]) + else: + result[0] = true + if result[1].len == 0: + result[1] = newEmptyNode() + +proc removePragmaFromExpr*(node: NimNode, name: string): bool = + if node.kind == nnkPragmaExpr: + let (res, newNode) = node[1].withoutPragma(name) + if res: + result = res + node[1] = newNode + +proc skipTypeDesc*(node: NimNode): NimNode = + result = node + while result.kind == nnkBracketExpr and result[0].eqIdent"typedesc": + result = result[1] + +proc collectTypeSection(n: NimNode): NimNode = + result = newNimNode(nnkTypeSection, n) + case n.kind + of nnkTypeSection: + result.add(n[0..^1]) + of nnkTypeDef: + result.add(n) + of nnkStmtList: + for b in n: + result.add(collectTypeSection(b)[0..^1]) + else: + error "expected type section", n + +proc applyTypeMacro*(body: NimNode, p: proc (typeSection: NimNode, poststmts: var seq[NimNode])): NimNode = + let inTypeSection = body.kind == nnkTypeDef + let typeSec = collectTypeSection(body) + var poststmts: seq[NimNode] + p(typeSec, poststmts) + if inTypeSection: + if poststmts.len == 0 and typeSec.len == 1: + result = typeSec[0] + else: + result = newTree(nnkTypeDef, genSym(nskType, "_"), newEmptyNode(), + newTree(nnkStmtListType, typeSec).add(poststmts).add(bindSym"void")) + elif poststmts.len == 0: + result = typeSec + else: + result = newStmtList(typeSec).add(poststmts) diff --git a/src/unions/unionfield.nim b/src/unions/unionfield.nim new file mode 100644 index 0000000..51e2b77 --- /dev/null +++ b/src/unions/unionfield.nim @@ -0,0 +1,104 @@ +import macros, private/utils + +proc patchTypeSection(typeSec: NimNode, poststmts: var seq[NimNode]) = + for td in typeSec: + var objectNode = td[^1] + while objectNode.kind in {nnkRefTy, nnkPtrTy}: objectNode = objectNode[0] + if objectNode.kind == nnkObjectTy: + proc doField(idefs: NimNode): NimNode = + if idefs[0].removePragmaFromExpr("union"): + result = newNimNode(nnkRecCase, idefs) + let types = idefs[1] + let name = idefs[0].realBasename + result.add(newIdentDefs(genSym(nskField, name & "Type"), + newTree(nnkBracketExpr, ident"range", infix(newLit(0), "..", newLit(types.len - 1))))) + for i, t in types: + result.add(newTree(nnkOfBranch, newLit(i), + newTree(nnkRecList, newIdentDefs(ident(name & $i), t)))) + else: + result = idefs + + proc eachField(rl: NimNode) = + case rl.kind + of nnkRecList: + for i in 0 ..< rl.len: + if rl[i].kind == nnkIdentDefs: + rl[i] = doField(rl[i]) + else: + eachField(rl[i]) + of nnkRecCase: + rl[0] = doField(rl[0]) + for b in rl[1..^1]: + eachField(b[^1]) + of nnkRecWhen: + for b in rl[1..^1]: + eachField(b[^1]) + else: error "unknown node kind", rl + + eachField(objectNode[^1]) + +macro unionField*(body): untyped = + result = applyTypeMacro(body, patchTypeSection) + +macro withUnionField*(obj: typed, field, body: untyped): untyped = + var impl = obj.getTypeImpl + while impl.kind in {nnkRefTy, nnkPtrTy}: impl = impl[0] + if impl.kind == nnkSym: impl = impl.getTypeImpl + expectKind impl, nnkObjectTy + let fieldName = field.realBasename + let fieldTypeName = fieldName & "Type" + for k in impl[^1]: + if k.kind == nnkRecCase and k[0][0].realBasename.eqIdent(fieldTypeName): + result = newTree(nnkCaseStmt, newDotExpr(obj, ident(fieldTypeName))) + for b in k[1..^1]: + result.add(newTree(nnkOfBranch).add(b[0..^2]).add(newStmtList( + newProc( + name = ident(fieldName), + params = [ident"untyped"], + body = newDotExpr(obj, ident(b[^1][0][0].realBasename)), + procType = nnkTemplateDef + ), + copy body))) + return + error "could not find field " & fieldName, obj + +template getUnionField*(obj: typed, field: untyped, ty: untyped): untyped = + obj.withUnionField(field): + when field is ty: + field + else: + raise newException(FieldError, "field " & astToStr(field) & " was not of type " & $ty) + +macro setUnionField*(obj: typed, field: untyped, value: untyped): untyped = + var objConstr: NimNode + var impl: NimNode + case obj.kind + of nnkObjConstr: + objConstr = obj + impl = obj[0].getTypeImpl + else: + objConstr = obj + impl = obj.getTypeImpl + while impl.kind in {nnkRefTy, nnkPtrTy}: impl = impl[0] + if impl.kind == nnkSym: impl = impl.getTypeImpl + expectKind impl, nnkObjectTy + let fieldName = field.realBasename + let fieldTypeName = fieldName & "Type" + for k in impl[^1]: + if k.kind == nnkRecCase and k[0][0].realBasename.eqIdent(fieldTypeName): + result = newTree(nnkWhenStmt) + for b in k[1..^1]: + let + (lhs1, rhs1) = (ident(fieldTypeName), b[0]) + (lhs2, rhs2) = (ident(b[^1][0][0].realBasename), value) + result.add(newTree(nnkElifBranch, infix(value, "is", b[^1][0][^2]), + if objConstr.kind == nnkObjConstr: + copy(objConstr).add( + newColonExpr(lhs1, rhs1), + newColonExpr(lhs2, rhs2)) + else: + newStmtList( + newAssignment(newDotExpr(objConstr, lhs1), rhs1), + newAssignment(newDotExpr(objConstr, lhs2), rhs2) + ))) + return diff --git a/tests/config.nims b/tests/config.nims new file mode 100644 index 0000000..3bb69f8 --- /dev/null +++ b/tests/config.nims @@ -0,0 +1 @@ +switch("path", "$projectDir/../src") \ No newline at end of file diff --git a/tests/test_conditional.nim b/tests/test_conditional.nim new file mode 100644 index 0000000..02439a5 --- /dev/null +++ b/tests/test_conditional.nim @@ -0,0 +1,33 @@ +import unions/conditional + +conditionalUnion: + type Foo = ref object + num: int + case branch: _ # type has to be _, name can also be _ + of Odd, num mod 2 == 1: # branch names can also be _ + name: string + of DoubleEven, num mod 4 == 0: + a*, b: int + of Even: + a: int + +block: + var f = Foo(num: 1) + f.name = "abc" + doAssert f.name == "abc" + doAssert f.branch == Odd + f.num = 16 + f.resetBranch() + doAssert f.a == 0 + doAssert f.b == 0 + f.a = 5 + doAssert f.a == 5 + f.b = 7 + doAssert f.b == 7 + doAssert f.branch == DoubleEven + f.num = 2 + doAssert f.branch == Even + doAssert f.a == 5 + f.resetBranch() + doAssert f.a == 0 + diff --git a/tests/test_flatbranch.nim b/tests/test_flatbranch.nim new file mode 100644 index 0000000..10da9da --- /dev/null +++ b/tests/test_flatbranch.nim @@ -0,0 +1,30 @@ +import unions/flatbranch + +block: + type FooKind = enum fkA, fkB, fkC + flattenBranches: + type + Foo = ref object + num: int + case kind: FooKind + of fkA: + name: string + of fkB: + a, b: int + else: + a: int + var f = Foo(num: 1, kind: fkA) + f.name = "abc" + doAssert f.kind == fkA + doAssert f.name == "abc" + f = Foo(num: 16, kind: fkB) + doAssert f.kind == fkB + doAssert f.a == 0 + doAssert f.b == 0 + f.a = 5 + doAssert f.a == 5 + f.b = 7 + doAssert f.b == 7 + f = Foo(num: 2, kind: fkC) + doAssert f.kind == fkC + doAssert f.a == 0 diff --git a/tests/test_unionfield.nim b/tests/test_unionfield.nim new file mode 100644 index 0000000..100ffd6 --- /dev/null +++ b/tests/test_unionfield.nim @@ -0,0 +1,51 @@ +import unions/unionfield + +block: + type Foo {.unionField.} = ref object + value {.union.}: (int, bool, string) + + var f1 = Foo().setUnionField(value, 3) + f1.withUnionField(value): + doAssert value is int + when value is int: + doAssert value == 3 + value = 4 + f1.withUnionField(value): + doAssert value is int + when value is int: + doAssert value == 4 + + var s: seq[string] + + proc foo(f: Foo) = + f.withUnionField(value): + s.add($typeof(value)) + + var s2 = @[Foo().setUnionField(value, true), Foo().setUnionField(value, 3), Foo().setUnionField(value, "abc")] + for a in s2: foo(a) + doAssert s == @["bool", "int", "string"] + +block: + type Foo {.unionField.} = ref object + a, b: int + value {.union.}: (int, bool, string) + c, d: int + + var s: seq[string] + + proc foo(f: Foo) = + f.withUnionField(value): + when value is bool: + doAssert value == true + doAssert (f.a, f.b, f.c, f.d) == (1, 0, 0, 0) + elif value is int: + doAssert value == 3 + doAssert (f.a, f.b, f.c, f.d) == (2, 0, 2, 0) + elif value is string: + doAssert value == "abc" + doAssert (f.a, f.b, f.c, f.d) == (0, 0, 0, 0) + s.add($typeof(value)) + + var s2 = @[Foo(a: 1).setUnionField(value, true), Foo(c: 2, a: 2).setUnionField(value, 3), Foo().setUnionField(value, "abc")] + for a in s2: foo(a) + doAssert s == @["bool", "int", "string"] diff --git a/unions.nimble b/unions.nimble new file mode 100644 index 0000000..73fb9ee --- /dev/null +++ b/unions.nimble @@ -0,0 +1,31 @@ +# Package + +version = "0.1.0" +author = "metagn" +description = "object variants utilities" +license = "MIT" +srcDir = "src" + + +# Dependencies + +requires "nim >= 1.6.0" + +when (NimMajor, NimMinor) >= (1, 4): + when (compiles do: import nimbleutils): + import nimbleutils + # https://github.com/metagn/nimbleutils + +task docs, "build docs for all modules": + when declared(buildDocs): + buildDocs(gitUrl = "https://github.com/metagn/unions") + else: + echo "docs task not implemented, need nimbleutils" + +task tests, "run tests for multiple backends and defines": + when declared(runTests): + runTests( + backends = {c, #[js]#}, + ) + else: + echo "tests task not implemented, need nimbleutils"