From f49b0415547acc96719162ecf289260c724e277c Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Fri, 13 Sep 2024 16:16:11 -0600 Subject: [PATCH 01/10] use instantiateAndRegister for binopvv Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 5 +- registration-config.json | 9 + src/BinOp.chpl | 293 ++++++++++------------- src/OperatorMsg.chpl | 468 +++++++++++++++---------------------- src/registry/Commands.chpl | 111 +++++++++ 5 files changed, 435 insertions(+), 451 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index e9fe63dc4c..c5c27060ab 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -520,7 +520,10 @@ def _binop(self, other: pdarray, op: str) -> pdarray: x1, x2, tmp_x1, tmp_x2 = broadcast_if_needed(self, other) except ValueError: raise ValueError(f"shape mismatch {self.shape} {other.shape}") - repMsg = generic_msg(cmd=f"binopvv{x1.ndim}D", args={"op": op, "a": x1, "b": x2}) + repMsg = generic_msg( + cmd=f"binopvv<{self.dtype},{other.dtype},{x1.ndim}>", + args={"op": op, "a": x1, "b": x2} + ) if tmp_x1: del x1 if tmp_x2: diff --git a/registration-config.json b/registration-config.json index 44afced3bf..0975fd8436 100644 --- a/registration-config.json +++ b/registration-config.json @@ -10,6 +10,15 @@ "bool", "bigint" ] + }, + "binop": { + "dtype": [ + "int", + "uint", + "real", + "bool", + "bigint" + ] } } } diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 8419c6b0f6..81983b5110 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -74,48 +74,48 @@ module BinOp :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - proc doBinOpvv(l, r, e, op: string, rname, pn, st) throws { - if e.etype == bool { + proc doBinOpvv(l, r, type etype, op: string, pn, st): MsgTuple throws { + var e = makeDistArray((...l.tupShape), etype); + + const nie = notImplementedError(pn,l.dtype,op,r.dtype); + + if etype == bool { // Since we know that the result type is a boolean, we know // that it either (1) is an operation between bools or (2) uses // a boolean operator (<, <=, etc.) if l.etype == bool && r.etype == bool { select op { when "|" { - e.a = l.a | r.a; + e = l.a | r.a; } when "&" { - e.a = l.a & r.a; + e = l.a & r.a; } when "^" { - e.a = l.a ^ r.a; + e = l.a ^ r.a; } when "==" { - e.a = l.a == r.a; + e = l.a == r.a; } when "!=" { - e.a = l.a != r.a; + e = l.a != r.a; } when "<" { - e.a = l.a:int < r.a:int; + e = l.a:int < r.a:int; } when ">" { - e.a = l.a:int > r.a:int; + e = l.a:int > r.a:int; } when "<=" { - e.a = l.a:int <= r.a:int; + e = l.a:int <= r.a:int; } when ">=" { - e.a = l.a:int >= r.a:int; + e = l.a:int >= r.a:int; } when "+" { - e.a = l.a | r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a | r.a; } + otherwise do return MsgTuple.error(nie); } } // All types support the same binary operations when the resultant @@ -125,203 +125,177 @@ module BinOp if ((l.etype == real && r.etype == bool) || (l.etype == bool && r.etype == real)) { select op { when "<" { - e.a = l.a:real < r.a:real; + e = l.a:real < r.a:real; } when ">" { - e.a = l.a:real > r.a:real; + e = l.a:real > r.a:real; } when "<=" { - e.a = l.a:real <= r.a:real; + e = l.a:real <= r.a:real; } when ">=" { - e.a = l.a:real >= r.a:real; + e = l.a:real >= r.a:real; } when "==" { - e.a = l.a:real == r.a:real; + e = l.a:real == r.a:real; } when "!=" { - e.a = l.a:real != r.a:real; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:real != r.a:real; } + otherwise do return MsgTuple.error(nie); } } else { select op { when "<" { - e.a = l.a < r.a; + e = l.a < r.a; } when ">" { - e.a = l.a > r.a; + e = l.a > r.a; } when "<=" { - e.a = l.a <= r.a; + e = l.a <= r.a; } when ">=" { - e.a = l.a >= r.a; + e = l.a >= r.a; } when "==" { - e.a = l.a == r.a; + e = l.a == r.a; } when "!=" { - e.a = l.a != r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a != r.a; } + otherwise do return MsgTuple.error(nie); } } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } // Since we know that both `l` and `r` are of type `int` and that // the resultant type is not bool (checked in first `if`), we know // what operations are supported based on the resultant type else if (l.etype == int && r.etype == int) || (l.etype == uint && r.etype == uint) { - if e.etype == int || e.etype == uint { + if etype == int || etype == uint { select op { when "+" { - e.a = l.a + r.a; + e = l.a + r.a; } when "-" { - e.a = l.a - r.a; + e = l.a - r.a; } when "*" { - e.a = l.a * r.a; + e = l.a * r.a; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = if ri != 0 then li/ri else 0; } - when "%" { // modulo - ref ea = e.a; + when "%" { // " modulo + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = if ri != 0 then li%ri else 0; } when "<<" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li << ri; } when ">>" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li >> ri; } when "<<<" { - e.a = rotl(l.a, r.a); + e = rotl(l.a, r.a); } when ">>>" { - e.a = rotr(l.a, r.a); + e = rotr(l.a, r.a); } when "&" { - e.a = l.a & r.a; + e = l.a & r.a; } when "|" { - e.a = l.a | r.a; + e = l.a | r.a; } when "^" { - e.a = l.a ^ r.a; + e = l.a ^ r.a; } - when "**" { - if || reduce (r.a<0){ - //instead of error, could we paste the below code but of type float? - var errorMsg = "Attempt to exponentiate base of type Int64 to negative exponent"; - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - e.a= l.a**r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + when "**" { + if || reduce (r.a<0) + then return MsgTuple.error("Attempt to exponentiate base of type Int64 to negative exponent"); + e= l.a**r.a; } + otherwise do return MsgTuple.error(nie); } - } else if e.etype == real { + } else if etype == real { select op { // True division is the only integer type that would result in a // resultant type of `real` when "/" { - e.a = l.a:real / r.a:real; + e = l.a:real / r.a:real; } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + otherwise do return MsgTuple.error(nie); + } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } - else if (e.etype == int && r.etype == uint) || - (e.etype == uint && r.etype == int) { + else if (etype == int && r.etype == uint) || + (etype == uint && r.etype == int) { select op { when ">>" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li >> ri; } when "<<" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li << ri; } when ">>>" { - e.a = rotr(l.a, r.a); + e = rotr(l.a, r.a); } when "<<<" { - e.a = rotl(l.a, r.a); - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = rotl(l.a, r.a); } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if (l.etype == uint && r.etype == int) || (l.etype == int && r.etype == uint) { + + writeln("correct dispatch... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); select op { when "+" { - e.a = l.a:real + r.a:real; + e = l.a:real + r.a:real; } when "-" { - e.a = l.a:real - r.a:real; + e = l.a:real - r.a:real; } when "/" { // truediv - e.a = l.a:real / r.a:real; + e = l.a:real / r.a:real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; var la = l.a:real; var ra = r.a:real; [(ei,li,ri) in zip(ea,la,ra)] ei = floorDivisionHelper(li, ri); } otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + writeln("wtf... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); + return MsgTuple.error(nie); + } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + writeln("returning... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); + return st.insert(new shared SymEntry(e)); } // If either RHS or LHS type is real, the same operations are supported and the // result will always be a `real`, so all 3 of these cases can be shared. @@ -329,169 +303,138 @@ module BinOp || (l.etype == real && r.etype == int)) { select op { when "+" { - e.a = l.a + r.a; + e = l.a + r.a; } when "-" { - e.a = l.a - r.a; + e = l.a - r.a; } when "*" { - e.a = l.a * r.a; + e = l.a * r.a; } when "/" { // truediv - e.a = l.a / r.a; + e = l.a / r.a; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = floorDivisionHelper(li, ri); } when "**" { - e.a= l.a**r.a; + e= l.a**r.a; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = modHelper(li, ri); } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == uint && r.etype == real) || (l.etype == real && r.etype == uint)) { select op { when "+" { - e.a = l.a:real + r.a:real; + e = l.a:real + r.a:real; } when "-" { - e.a = l.a:real - r.a:real; + e = l.a:real - r.a:real; } when "*" { - e.a = l.a:real * r.a:real; + e = l.a:real * r.a:real; } when "/" { // truediv - e.a = l.a:real / r.a:real; + e = l.a:real / r.a:real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = floorDivisionHelper(li, ri); } when "**" { - e.a= l.a:real**r.a:real; + e= l.a:real**r.a:real; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = modHelper(li, ri); } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == int && r.etype == bool) || (l.etype == bool && r.etype == int)) { select op { when "+" { // Since we don't know which of `l` or `r` is the int and which is the `bool`, // we can just cast both to int, which will be a noop for the vector that is // already `int` - e.a = l.a:int + r.a:int; + e = l.a:int + r.a:int; } when "-" { - e.a = l.a:int - r.a:int; + e = l.a:int - r.a:int; } when "*" { - e.a = l.a:int * r.a:int; + e = l.a:int * r.a:int; } when ">>" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li:int >> ri:int; } when "<<" { - ref ea = e.a; + ref ea = e; ref la = l.a; ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] if (0 <= ri && ri < 64) then ei = li:int << ri:int; } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == uint && r.etype == bool) || (l.etype == bool && r.etype == uint)) { select op { when "+" { - e.a = l.a:uint + r.a:uint; + e = l.a:uint + r.a:uint; } when "-" { - e.a = l.a:uint - r.a:uint; + e = l.a:uint - r.a:uint; } when "*" { - e.a = l.a:uint * r.a:uint; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:uint * r.a:uint; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == real && r.etype == bool) || (l.etype == bool && r.etype == real)) { select op { when "+" { - e.a = l.a:real + r.a:real; + e = l.a:real + r.a:real; } when "-" { - e.a = l.a:real - r.a:real; + e = l.a:real - r.a:real; } when "*" { - e.a = l.a:real * r.a:real; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:real * r.a:real; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if (l.etype == bool && r.etype == bool) { select op { when "<<" { - e.a = l.a:int << r.a:int; + e = l.a:int << r.a:int; } when ">>" { - e.a = l.a:int >> r.a:int; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:int >> r.a:int; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); + } else { + return MsgTuple.error(nie); } - var errorMsg = notImplementedError(pn,l.dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } proc doBinOpvs(l, val, e, op: string, dtype, rname, pn, st) throws { diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index fee41a677e..6a124f574e 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -36,22 +36,22 @@ module OperatorMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - @arkouda.registerND - proc binopvvMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + @arkouda.instantiateAndRegister + proc binopvv(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, + type binop_dtype_a, + type binop_dtype_b, + param array_nd: int + ): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string; // response message - - const op = msgArgs.getValueOf("op"); - const aname = msgArgs.getValueOf("a"); - const bname = msgArgs.getValueOf("b"); - var rname = st.nextName(); - var left: borrowed GenSymEntry = getGenericTypedArrayEntry(aname, st); - var right: borrowed GenSymEntry = getGenericTypedArrayEntry(bname, st); - + const l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd), + r = st[msgArgs['b']]: borrowed SymEntry(binop_dtype_b, array_nd), + op = msgArgs['op'].toScalar(string); + omLogger.debug(getModuleName(), getRoutineName(), getLineNumber(), "cmd: %? op: %? left pdarray: %? right pdarray: %?".format( - cmd,op,st.attrib(aname),st.attrib(bname))); + cmd,op,st.attrib(msgArgs['a'].val), + st.attrib(msgArgs['b'].val))); use Set; // This boolOps set is a filter to determine the output type for the operation. @@ -71,307 +71,225 @@ module OperatorMsg realOps.add("/"); realOps.add("//"); - select (left.dtype, right.dtype) { - when (DType.Int64, DType.Int64) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,int, nd); - - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } else if op == "/" { - // True division is the only case in this int, int case - // that results in a `real` symbol table entry. - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvv(l, r, e, op, rname, pn, st); + const rname = st.nextName(); + + if binop_dtype_a == int && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); + } else if op == "/" { + // True division is the only case in this int, int case + // that results in a `real` symbol table entry. + return doBinOpvv(l, r, real, op, pn, st); + } + return doBinOpvv(l, r, int, op, pn, st); + } else if binop_dtype_a == int && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); + } + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Int64, DType.Float64) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,real, nd); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Float64, DType.Int64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.UInt64, DType.Float64) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,real, nd); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Float64, DType.UInt64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + // For cases where a boolean operand is involved, the only + // possible resultant type is `bool` + else if binop_dtype_a == bool && binop_dtype_b == bool { + if (op == "<<") || (op == ">>" ) { + return doBinOpvv(l, r, int, op, pn, st); } - when (DType.Float64, DType.Float64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,real, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, bool, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - // For cases where a boolean operand is involved, the only - // possible resultant type is `bool` - when (DType.Bool, DType.Bool) { - var l = toSymEntry(left,bool, nd); - var r = toSymEntry(right,bool, nd); - if (op == "<<") || (op == ">>" ) { - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, int, op, pn, st); + } + else if binop_dtype_a == int && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Bool, DType.Int64) { - var l = toSymEntry(left,bool, nd); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, int, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Int64, DType.Bool) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Bool, DType.Float64) { - var l = toSymEntry(left,bool, nd); - var r = toSymEntry(right,real, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, real, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Float64, DType.Bool) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.Bool, DType.UInt64) { - var l = toSymEntry(left,bool, nd); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvv(l, r, e, op, rname, pn, st); + return doBinOpvv(l, r, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.UInt64, DType.Bool) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvv(l, r, e, op, rname, pn, st); + if op == "/"{ + return doBinOpvv(l, r, real, op, pn, st); + } else { + return doBinOpvv(l, r, uint, op, pn, st); } - when (DType.UInt64, DType.UInt64) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - if op == "/"{ - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); - } else { - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvv(l, r, e, op, rname, pn, st); - } + } + else if binop_dtype_a == uint && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvv(l, r , bool, op, pn, st); } - when (DType.UInt64, DType.Int64) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r , e, op, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvv(l, r, e, op, rname, pn, st); - } + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpvv(l, r, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpvv(l, r, uint, op, pn, st); } - when (DType.Int64, DType.UInt64) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvv(l, r, e, op, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvv(l, r, e, op, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvv(l, r, e, op, rname, pn, st); - } + } + else if binop_dtype_a == int && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvv(l, r, bool, op, pn, st); } - when (DType.BigInt, DType.BigInt) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpvv(l, r, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpvv(l, r, int, op, pn, st); + } + } + else if binop_dtype_a == bigint && binop_dtype_b == bigint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.BigInt, DType.Int64) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == bigint && binop_dtype_b == int { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.BigInt, DType.UInt64) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == bigint && binop_dtype_b == uint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.BigInt, DType.Bool) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == bigint && binop_dtype_b == bool { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.Int64, DType.BigInt) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == int && binop_dtype_b == bigint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.UInt64, DType.BigInt) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == uint && binop_dtype_b == bigint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } - when (DType.Bool, DType.BigInt) { - var l = toSymEntry(left,bool, nd); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == bool && binop_dtype_b == bigint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); var repMsg = "created %s".format(st.attrib(rname)); return new MsgTuple(repMsg, MsgType.NORMAL); } + // call bigint specific func which returns dist bigint array + var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); + var repMsg = "created %s".format(st.attrib(rname)); + return new MsgTuple(repMsg, MsgType.NORMAL); + } else { + writeln("unreachable???"); + var errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); } - var errorMsg = unrecognizedTypeError(pn, "("+dtype2str(left.dtype)+","+dtype2str(right.dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } /* diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index 931074e4b5..3a6540a7d3 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -19,6 +19,15 @@ param regConfig = """ "bool", "bigint" ] + }, + "binop": { + "dtype": [ + "int", + "uint", + "real", + "bool", + "bigint" + ] } } } @@ -1576,6 +1585,108 @@ proc ark_repeatFlat_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: bor return ManipulationMsg.repeatFlatMsg(cmd, msgArgs, st, array_dtype=bigint, array_nd=1); registerFunction('repeatFlat', ark_repeatFlat_bigint_1, 'ManipulationMsg', 902); +import OperatorMsg; + +proc ark_binopvv_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('binopvv', ark_binopvv_int_int_1, 'OperatorMsg', 40); + +proc ark_binopvv_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvv', ark_binopvv_int_uint_1, 'OperatorMsg', 40); + +proc ark_binopvv_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('binopvv', ark_binopvv_int_real_1, 'OperatorMsg', 40); + +proc ark_binopvv_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvv', ark_binopvv_int_bool_1, 'OperatorMsg', 40); + +proc ark_binopvv_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvv', ark_binopvv_int_bigint_1, 'OperatorMsg', 40); + +proc ark_binopvv_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvv', ark_binopvv_uint_int_1, 'OperatorMsg', 40); + +proc ark_binopvv_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvv', ark_binopvv_uint_uint_1, 'OperatorMsg', 40); + +proc ark_binopvv_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvv', ark_binopvv_uint_real_1, 'OperatorMsg', 40); + +proc ark_binopvv_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvv', ark_binopvv_uint_bool_1, 'OperatorMsg', 40); + +proc ark_binopvv_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvv', ark_binopvv_uint_bigint_1, 'OperatorMsg', 40); + +proc ark_binopvv_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('binopvv', ark_binopvv_real_int_1, 'OperatorMsg', 40); + +proc ark_binopvv_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvv', ark_binopvv_real_uint_1, 'OperatorMsg', 40); + +proc ark_binopvv_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('binopvv', ark_binopvv_real_real_1, 'OperatorMsg', 40); + +proc ark_binopvv_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvv', ark_binopvv_real_bool_1, 'OperatorMsg', 40); + +proc ark_binopvv_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvv', ark_binopvv_real_bigint_1, 'OperatorMsg', 40); + +proc ark_binopvv_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('binopvv', ark_binopvv_bool_int_1, 'OperatorMsg', 40); + +proc ark_binopvv_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvv', ark_binopvv_bool_uint_1, 'OperatorMsg', 40); + +proc ark_binopvv_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('binopvv', ark_binopvv_bool_real_1, 'OperatorMsg', 40); + +proc ark_binopvv_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvv', ark_binopvv_bool_bool_1, 'OperatorMsg', 40); + +proc ark_binopvv_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvv', ark_binopvv_bool_bigint_1, 'OperatorMsg', 40); + +proc ark_binopvv_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvv', ark_binopvv_bigint_int_1, 'OperatorMsg', 40); + +proc ark_binopvv_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvv', ark_binopvv_bigint_uint_1, 'OperatorMsg', 40); + +proc ark_binopvv_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvv', ark_binopvv_bigint_real_1, 'OperatorMsg', 40); + +proc ark_binopvv_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvv', ark_binopvv_bigint_bool_1, 'OperatorMsg', 40); + +proc ark_binopvv_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvv', ark_binopvv_bigint_bigint_1, 'OperatorMsg', 40); + import RandMsg; proc ark_randint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do From 36a014e0aee01346521ace3398153523dd38b421 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Mon, 16 Sep 2024 08:33:36 -0600 Subject: [PATCH 02/10] update multi-dim config file Signed-off-by: Jeremiah Corrado --- .configs/registration-config-multi-dim.json | 9 +++++++++ .configs/registration-config-single-dim.json | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/.configs/registration-config-multi-dim.json b/.configs/registration-config-multi-dim.json index e8e0c481d9..d3ebe75a45 100644 --- a/.configs/registration-config-multi-dim.json +++ b/.configs/registration-config-multi-dim.json @@ -10,6 +10,15 @@ "bool", "bigint" ] + }, + "binop": { + "dtype": [ + "int", + "uint", + "real", + "bool", + "bigint" + ] } } } diff --git a/.configs/registration-config-single-dim.json b/.configs/registration-config-single-dim.json index 44afced3bf..0975fd8436 100644 --- a/.configs/registration-config-single-dim.json +++ b/.configs/registration-config-single-dim.json @@ -10,6 +10,15 @@ "bool", "bigint" ] + }, + "binop": { + "dtype": [ + "int", + "uint", + "real", + "bool", + "bigint" + ] } } } From 59076a143e352ce27de62242f9c48b0bce8b1e0c Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Mon, 16 Sep 2024 14:35:58 -0600 Subject: [PATCH 03/10] continue swapping registerND for instantiateAndRegister in OperatorMsg Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 4 +- src/BinOp.chpl | 276 ++++++++------------ src/OperatorMsg.chpl | 513 ++++++++++--------------------------- src/registry/Commands.chpl | 100 ++++++++ 4 files changed, 346 insertions(+), 547 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index c5c27060ab..31b8a979eb 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -540,8 +540,8 @@ def _binop(self, other: pdarray, op: str) -> pdarray: if dt not in DTypes: raise TypeError(f"Unhandled scalar type: {other} ({type(other)})") repMsg = generic_msg( - cmd=f"binopvs{self.ndim}D", - args={"op": op, "a": self, "dtype": dt, "value": other}, + cmd=f"binopvsMsg<{self.dtype},{dt},{self.ndim}>", + args={"op": op, "a": self, "value": other}, ) return create_pdarray(repMsg) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 81983b5110..a93a08ca3f 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -437,48 +437,48 @@ module BinOp } } - proc doBinOpvs(l, val, e, op: string, dtype, rname, pn, st) throws { - if e.etype == bool { + proc doBinOpvs(l, val, type etype, op: string, pn, st): MsgTuple throws { + var e = makeDistArray((...l.tupShape), etype); + + const nie = notImplementedError(pn,"%s %s %s".format(type2str(l.a.eltType),op,type2str(val.type))); + + if etype == bool { // Since we know that the result type is a boolean, we know // that it either (1) is an operation between bools or (2) uses // a boolean operator (<, <=, etc.) if l.etype == bool && val.type == bool { select op { when "|" { - e.a = l.a | val; + e = l.a | val; } when "&" { - e.a = l.a & val; + e = l.a & val; } when "^" { - e.a = l.a ^ val; + e = l.a ^ val; } when "==" { - e.a = l.a == val; + e = l.a == val; } when "!=" { - e.a = l.a != val; + e = l.a != val; } when "<" { - e.a = l.a:int < val:int; + e = l.a:int < val:int; } when ">" { - e.a = l.a:int > val:int; + e = l.a:int > val:int; } when "<=" { - e.a = l.a:int <= val:int; + e = l.a:int <= val:int; } when ">=" { - e.a = l.a:int >= val:int; + e = l.a:int >= val:int; } when "+" { - e.a = l.a | val; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a | val; } + otherwise do return MsgTuple.error(nie); } } // All types support the same binary operations when the resultant @@ -488,196 +488,168 @@ module BinOp if ((l.etype == real && val.type == bool) || (l.etype == bool && val.type == real)) { select op { when "<" { - e.a = l.a:real < val:real; + e = l.a:real < val:real; } when ">" { - e.a = l.a:real > val:real; + e = l.a:real > val:real; } when "<=" { - e.a = l.a:real <= val:real; + e = l.a:real <= val:real; } when ">=" { - e.a = l.a:real >= val:real; + e = l.a:real >= val:real; } when "==" { - e.a = l.a:real == val:real; + e = l.a:real == val:real; } when "!=" { - e.a = l.a:real != val:real; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:real != val:real; } + otherwise do return MsgTuple.error(nie); } } else { select op { when "<" { - e.a = l.a < val; + e = l.a < val; } when ">" { - e.a = l.a > val; + e = l.a > val; } when "<=" { - e.a = l.a <= val; + e = l.a <= val; } when ">=" { - e.a = l.a >= val; + e = l.a >= val; } when "==" { - e.a = l.a == val; + e = l.a == val; } when "!=" { - e.a = l.a != val; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a != val; } + otherwise do return MsgTuple.error(nie); } } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } // Since we know that both `l` and `r` are of type `int` and that // the resultant type is not bool (checked in first `if`), we know // what operations are supported based on the resultant type else if (l.etype == int && val.type == int) || (l.etype == uint && val.type == uint) { - if e.etype == int || e.etype == uint { + if etype == int || etype == uint { select op { when "+" { - e.a = l.a + val; + e = l.a + val; } when "-" { - e.a = l.a - val; + e = l.a - val; } when "*" { - e.a = l.a * val; + e = l.a * val; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = if val != 0 then li/val else 0; } - when "%" { // modulo - ref ea = e.a; + when "%" { // modulo " + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = if val != 0 then li%val else 0; } when "<<" { if 0 <= val && val < 64 { - e.a = l.a << val; + e = l.a << val; } } when ">>" { if 0 <= val && val < 64 { - e.a = l.a >> val; + e = l.a >> val; } } when "<<<" { - e.a = rotl(l.a, val); + e = rotl(l.a, val); } when ">>>" { - e.a = rotr(l.a, val); + e = rotr(l.a, val); } when "&" { - e.a = l.a & val; + e = l.a & val; } when "|" { - e.a = l.a | val; + e = l.a | val; } when "^" { - e.a = l.a ^ val; + e = l.a ^ val; } when "**" { - e.a= l.a**val; + e= l.a**val; } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - } else if e.etype == real { + } else if etype == real { select op { // True division is the only integer type that would result in a // resultant type of `real` when "/" { - e.a = l.a:real / val:real; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:real / val:real; } + otherwise do return MsgTuple.error(nie); } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } - else if (e.etype == int && val.type == uint) || - (e.etype == uint && val.type == int) { + else if (etype == int && val.type == uint) || + (etype == uint && val.type == int) { select op { when ">>" { if 0 <= val && val < 64 { - e.a = l.a >> val:l.etype; + e = l.a >> val:l.etype; } } when "<<" { if 0 <= val && val < 64 { - e.a = l.a << val:l.etype; + e = l.a << val:l.etype; } } when ">>>" { - e.a = rotr(l.a, val:l.etype); + e = rotr(l.a, val:l.etype); } when "<<<" { - e.a = rotl(l.a, val:l.etype); + e = rotl(l.a, val:l.etype); } when "+" { - e.a = l.a + val:l.etype; + e = l.a + val:l.etype; } when "-" { - e.a = l.a - val:l.etype; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a - val:l.etype; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if (l.etype == bool && val.type == bool) { select op { when ">>" { if(val){ - e.a = l.a:int >> val:int; + e = l.a:int >> val:int; }else{ - e.a = l.a:int; + e = l.a:int; } } when "<<" { if(val){ - e.a = l.a:int << val:int; + e = l.a:int << val:int; }else{ - e.a = l.a:int; + e = l.a:int; } } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } // If either RHS or LHS type is real, the same operations are supported and the // result will always be a `real`, so all 3 of these cases can be shared. @@ -685,154 +657,130 @@ module BinOp || (l.etype == real && val.type == int)) { select op { when "+" { - e.a = l.a + val; + e = l.a + val; } when "-" { - e.a = l.a - val; + e = l.a - val; } when "*" { - e.a = l.a * val; + e = l.a * val; } when "/" { // truediv - e.a = l.a / val; - } + e = l.a / val; + } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li, val); } when "**" { - e.a= l.a**val; + e= l.a**val; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = modHelper(li, val); } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } - else if e.etype == real && ((l.etype == uint && val.type == int) || (l.etype == int && val.type == uint)) { + else if etype == real && ((l.etype == uint && val.type == int) || (l.etype == int && val.type == uint)) { select op { when "+" { - e.a = l.a: real + val: real; + e = l.a: real + val: real; } when "-" { - e.a = l.a: real - val: real; + e = l.a: real - val: real; } when "/" { // truediv - e.a = l.a: real / val: real; + e = l.a: real / val: real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; var la = l.a; [(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li, val:real); } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == uint && val.type == real) || (l.etype == real && val.type == uint)) { select op { when "+" { - e.a = l.a: real + val: real; + e = l.a: real + val: real; } when "-" { - e.a = l.a: real - val: real; + e = l.a: real - val: real; } when "*" { - e.a = l.a: real * val: real; + e = l.a: real * val: real; } when "/" { // truediv - e.a = l.a: real / val: real; + e = l.a: real / val: real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = floorDivisionHelper(li, val); } when "**" { - e.a= l.a: real**val: real; + e= l.a: real**val: real; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = modHelper(li, val); } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == int && val.type == bool) || (l.etype == bool && val.type == int)) { select op { when "+" { // Since we don't know which of `l` or `r` is the int and which is the `bool`, // we can just cast both to int, which will be a noop for the vector that is // already `int` - e.a = l.a:int + val:int; + e = l.a:int + val:int; } when "-" { - e.a = l.a:int - val:int; + e = l.a:int - val:int; } when "*" { - e.a = l.a:int * val:int; + e = l.a:int * val:int; } when ">>" { if 0 <= val:int && val:int < 64 { - e.a = l.a:int >> val:int; + e = l.a:int >> val:int; } } when "<<" { if 0 <= val:int && val:int < 64 { - e.a = l.a:int << val:int; + e = l.a:int << val:int; } } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((l.etype == real && val.type == bool) || (l.etype == bool && val.type == real)) { select op { when "+" { - e.a = l.a:real + val:real; + e = l.a:real + val:real; } when "-" { - e.a = l.a:real - val:real; + e = l.a:real - val:real; } when "*" { - e.a = l.a:real * val:real; - } - otherwise { - var errorMsg = notImplementedError(pn,l.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = l.a:real * val:real; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); + } else { + const errorMsg = unrecognizedTypeError(pn, "("+dtype2str(l.dtype)+","+type2str(val.type)+")"); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return MsgTuple.error(errorMsg); } - var errorMsg = unrecognizedTypeError(pn, "("+dtype2str(l.dtype)+","+dtype2str(dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } proc doBinOpsv(val, r, e, op: string, dtype, rname, pn, st) throws { diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 6a124f574e..b54063ab7a 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -194,104 +194,21 @@ module OperatorMsg return doBinOpvv(l, r, int, op, pn, st); } } - else if binop_dtype_a == bigint && binop_dtype_b == bigint { + else if binop_dtype_a == bigint || binop_dtype_b == bigint { if boolOps.contains(op) { // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == bigint && binop_dtype_b == int { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == bigint && binop_dtype_b == uint { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == bigint && binop_dtype_b == bool { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == int && binop_dtype_b == bigint { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == uint && binop_dtype_b == bigint { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - else if binop_dtype_a == bool && binop_dtype_b == bigint { - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); } // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvv(l, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + const (tmp, max_bits) = doBigIntBinOpvv(l, r, op); + return st.insert(new shared SymEntry(tmp, max_bits)); } else { - writeln("unreachable???"); - var errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); + const errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); return new MsgTuple(errorMsg, MsgType.ERROR); } } - + /* Parse and respond to binopvs message. vs == vector op scalar @@ -305,22 +222,23 @@ module OperatorMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - @arkouda.registerND - proc binopvsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + @arkouda.instantiateAndRegister + proc binopvs(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, + type binop_dtype_a, + type binop_dtype_b, + param array_nd: int + ): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string = ""; // response message - const aname = msgArgs.getValueOf("a"); - const op = msgArgs.getValueOf("op"); - const value = msgArgs.get("value"); + const l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd), + val = msgArgs['value'].toScalar(binop_dtype_b), + op = msgArgs['op'].toScalar(string); - const dtype = str2dtype(msgArgs.getValueOf("dtype")); - var rname = st.nextName(); - var left: borrowed GenSymEntry = getGenericTypedArrayEntry(aname, st); + omLogger.debug(getModuleName(), getRoutineName(), getLineNumber(), + "cmd: %? op: %? left pdarray: %? scalar: %?".format( + cmd,op,st.attrib(msgArgs['a'].val), val)); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "op: %s dtype: %? pdarray: %? scalar: %?".format( - op,dtype,st.attrib(aname),value.getValue())); + const rname = st.nextName(); use Set; // This boolOps set is a filter to determine the output type for the operation. @@ -340,307 +258,140 @@ module OperatorMsg realOps.add("/"); realOps.add("//"); - select (left.dtype, dtype) { - when (DType.Int64, DType.Int64) { - var l = toSymEntry(left,int, nd); - var val = value.getIntValue(); - - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } else if op == "/" { - // True division is the only case in this int, int case - // that results in a `real` symbol table entry. - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - when (DType.Int64, DType.Float64) { - var l = toSymEntry(left,int, nd); - var val = value.getRealValue(); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - when (DType.Float64, DType.Int64) { - var l = toSymEntry(left,real, nd); - var val = value.getIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - when (DType.UInt64, DType.Float64) { - var l = toSymEntry(left,uint, nd); - var val = value.getRealValue(); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + if binop_dtype_a == int && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); + } else if op == "/" { + // True division is the only case in this int, int case + // that results in a `real` symbol table entry. + return doBinOpvs(l, val, real, op, pn, st); } - when (DType.Float64, DType.UInt64) { - var l = toSymEntry(left,real, nd); - var val = value.getUIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, int, op, pn, st); + } else if binop_dtype_a == int && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Float64, DType.Float64) { - var l = toSymEntry(left,real, nd); - var val = value.getRealValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - // For cases where a boolean operand is involved, the only - // possible resultant type is `bool` - when (DType.Bool, DType.Bool) { - var l = toSymEntry(left,bool, nd); - var val = value.getBoolValue(); - if (op == "<<") || (op == ">>") { - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Bool, DType.Int64) { - var l = toSymEntry(left,bool, nd); - var val = value.getIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Int64, DType.Bool) { - var l = toSymEntry(left,int, nd); - var val = value.getBoolValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Bool, DType.Float64) { - var l = toSymEntry(left,bool, nd); - var val = value.getRealValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, real, op, pn, st); + } + // For cases where a boolean operand is involved, the only + // possible resultant type is `bool` + else if binop_dtype_a == bool && binop_dtype_b == bool { + if (op == "<<") || (op == ">>") { + return doBinOpvs(l, val, int, op, pn, st); } - when (DType.Float64, DType.Bool) { - var l = toSymEntry(left,real, nd); - var val = value.getBoolValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, bool, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Bool, DType.UInt64) { - var l = toSymEntry(left,bool, nd); - var val = value.getUIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, int, op, pn, st); + } + else if binop_dtype_a == int && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.UInt64, DType.Bool) { - var l = toSymEntry(left,uint, nd); - var val = value.getBoolValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); + return doBinOpvs(l, val, int, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.UInt64, DType.UInt64) { - var l = toSymEntry(left,uint, nd); - var val = value.getUIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - if op == "/"{ - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } else { - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.UInt64, DType.Int64) { - var l = toSymEntry(left,uint, nd); - var val = value.getIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, l.tupShape, uint); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } + return doBinOpvs(l, val, real, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.Int64, DType.UInt64) { - var l = toSymEntry(left,int, nd); - var val = value.getUIntValue(); - if boolOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, bool); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, l.tupShape, real); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, l.tupShape, int); - return doBinOpvs(l, val, e, op, dtype, rname, pn, st); - } + return doBinOpvs(l, val, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.BigInt, DType.BigInt) { - var l = toSymEntry(left,bigint, nd); - var val = value.getBigIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return doBinOpvs(l, val, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.BigInt, DType.Int64) { - var l = toSymEntry(left,bigint, nd); - var val = value.getIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + if op == "/"{ + return doBinOpvs(l, val, real, op, pn, st); + } else { + return doBinOpvs(l, val, uint, op, pn, st); } - when (DType.BigInt, DType.UInt64) { - var l = toSymEntry(left,bigint, nd); - var val = value.getUIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == uint && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.BigInt, DType.Bool) { - var l = toSymEntry(left,bigint, nd); - var val = value.getBoolValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpvs(l, val, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpvs(l, val, uint, op, pn, st); } - when (DType.Int64, DType.BigInt) { - var l = toSymEntry(left,int, nd); - var val = value.getBigIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == int && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpvs(l, val, bool, op, pn, st); } - when (DType.UInt64, DType.BigInt) { - var l = toSymEntry(left,uint, nd); - var val = value.getBigIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpvs(l, val, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpvs(l, val, int, op, pn, st); } - when (DType.Bool, DType.BigInt) { - var l = toSymEntry(left,bool, nd); - var val = value.getBigIntValue(); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpvs(l, val, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == bigint || binop_dtype_b == bigint { + if boolOps.contains(op) { + // call bigint specific func which returns distr bool array + return st.insert(new shared SymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); } + // call bigint specific func which returns dist bigint array + const (tmp, max_bits) = doBigIntBinOpvs(l, val, op); + return st.insert(new shared SymEntry(tmp, max_bits)); + } else { + const errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return MsgTuple.error(errorMsg); } - var errorMsg = unrecognizedTypeError(pn, "("+dtype2str(left.dtype)+","+dtype2str(dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } /* diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index 3a6540a7d3..c2b3fe3374 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -1687,6 +1687,106 @@ proc ark_binopvv_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); registerFunction('binopvv', ark_binopvv_bigint_bigint_1, 'OperatorMsg', 40); +proc ark_binopvsMsg_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_int_int_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_int_uint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_int_real_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_int_bool_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_int_bigint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_uint_int_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_uint_uint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_uint_real_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_uint_bool_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_uint_bigint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_real_int_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_real_uint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_real_real_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_real_bool_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_real_bigint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bool_int_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bool_uint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bool_real_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bool_bool_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bool_bigint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bigint_int_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bigint_uint_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bigint_real_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bigint_bool_1, 'OperatorMsg', 280); + +proc ark_binopvsMsg_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvsMsg', ark_binopvsMsg_bigint_bigint_1, 'OperatorMsg', 280); + import RandMsg; proc ark_randint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do From a78fd9621c07eba2345f7a5e1a36235fdda69267 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Mon, 16 Sep 2024 15:34:48 -0600 Subject: [PATCH 04/10] fix bigint dispatching in OperatorMsg Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 2 +- src/OperatorMsg.chpl | 8 +- src/registry/Commands.chpl | 150 ++++++++++++++++++------------------- 3 files changed, 82 insertions(+), 78 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 31b8a979eb..132328cdfb 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -540,7 +540,7 @@ def _binop(self, other: pdarray, op: str) -> pdarray: if dt not in DTypes: raise TypeError(f"Unhandled scalar type: {other} ({type(other)})") repMsg = generic_msg( - cmd=f"binopvsMsg<{self.dtype},{dt},{self.ndim}>", + cmd=f"binopvs<{self.dtype},{dt},{self.ndim}>", args={"op": op, "a": self, "value": other}, ) return create_pdarray(repMsg) diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index b54063ab7a..7b7b013fc6 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -194,7 +194,9 @@ module OperatorMsg return doBinOpvv(l, r, int, op, pn, st); } } - else if binop_dtype_a == bigint || binop_dtype_b == bigint { + else if (binop_dtype_a == bigint || binop_dtype_b == bigint) && + !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) + { if boolOps.contains(op) { // call bigint specific func which returns distr bool array return st.insert(new shared SymEntry(doBigIntBinOpvvBoolReturn(l, r, op))); @@ -379,7 +381,9 @@ module OperatorMsg return doBinOpvs(l, val, int, op, pn, st); } } - else if binop_dtype_a == bigint || binop_dtype_b == bigint { + else if (binop_dtype_a == bigint || binop_dtype_b == bigint) && + !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) + { if boolOps.contains(op) { // call bigint specific func which returns distr bool array return st.insert(new shared SymEntry(doBigIntBinOpvsBoolReturn(l, val, op))); diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index c2b3fe3374..8c01a3b726 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -1687,105 +1687,105 @@ proc ark_binopvv_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: return OperatorMsg.binopvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); registerFunction('binopvv', ark_binopvv_bigint_bigint_1, 'OperatorMsg', 40); -proc ark_binopvsMsg_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_int_int_1, 'OperatorMsg', 280); +proc ark_binopvs_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('binopvs', ark_binopvs_int_int_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_int_uint_1, 'OperatorMsg', 280); +proc ark_binopvs_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvs', ark_binopvs_int_uint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_int_real_1, 'OperatorMsg', 280); +proc ark_binopvs_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('binopvs', ark_binopvs_int_real_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_int_bool_1, 'OperatorMsg', 280); +proc ark_binopvs_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvs', ark_binopvs_int_bool_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_int_bigint_1, 'OperatorMsg', 280); +proc ark_binopvs_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvs', ark_binopvs_int_bigint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_uint_int_1, 'OperatorMsg', 280); +proc ark_binopvs_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvs', ark_binopvs_uint_int_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_uint_uint_1, 'OperatorMsg', 280); +proc ark_binopvs_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvs', ark_binopvs_uint_uint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_uint_real_1, 'OperatorMsg', 280); +proc ark_binopvs_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvs', ark_binopvs_uint_real_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_uint_bool_1, 'OperatorMsg', 280); +proc ark_binopvs_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvs', ark_binopvs_uint_bool_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_uint_bigint_1, 'OperatorMsg', 280); +proc ark_binopvs_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvs', ark_binopvs_uint_bigint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_real_int_1, 'OperatorMsg', 280); +proc ark_binopvs_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('binopvs', ark_binopvs_real_int_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_real_uint_1, 'OperatorMsg', 280); +proc ark_binopvs_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvs', ark_binopvs_real_uint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_real_real_1, 'OperatorMsg', 280); +proc ark_binopvs_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('binopvs', ark_binopvs_real_real_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_real_bool_1, 'OperatorMsg', 280); +proc ark_binopvs_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvs', ark_binopvs_real_bool_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_real_bigint_1, 'OperatorMsg', 280); +proc ark_binopvs_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvs', ark_binopvs_real_bigint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bool_int_1, 'OperatorMsg', 280); +proc ark_binopvs_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('binopvs', ark_binopvs_bool_int_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bool_uint_1, 'OperatorMsg', 280); +proc ark_binopvs_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvs', ark_binopvs_bool_uint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bool_real_1, 'OperatorMsg', 280); +proc ark_binopvs_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('binopvs', ark_binopvs_bool_real_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bool_bool_1, 'OperatorMsg', 280); +proc ark_binopvs_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvs', ark_binopvs_bool_bool_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bool_bigint_1, 'OperatorMsg', 280); +proc ark_binopvs_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvs', ark_binopvs_bool_bigint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bigint_int_1, 'OperatorMsg', 280); +proc ark_binopvs_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('binopvs', ark_binopvs_bigint_int_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bigint_uint_1, 'OperatorMsg', 280); +proc ark_binopvs_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopvs', ark_binopvs_bigint_uint_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bigint_real_1, 'OperatorMsg', 280); +proc ark_binopvs_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('binopvs', ark_binopvs_bigint_real_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bigint_bool_1, 'OperatorMsg', 280); +proc ark_binopvs_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopvs', ark_binopvs_bigint_bool_1, 'OperatorMsg', 228); -proc ark_binopvsMsg_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do - return OperatorMsg.binopvsMsg(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); -registerFunction('binopvsMsg', ark_binopvsMsg_bigint_bigint_1, 'OperatorMsg', 280); +proc ark_binopvs_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopvs', ark_binopvs_bigint_bigint_1, 'OperatorMsg', 228); import RandMsg; From c6c2935784b2ec13cdc0867431166d5e3f061fc8 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Tue, 17 Sep 2024 10:09:30 -0600 Subject: [PATCH 05/10] finish refactoring OperatorMsg Signed-off-by: Jeremiah Corrado --- arkouda/pdarrayclass.py | 14 +- src/BinOp.chpl | 270 ++--- src/OperatorMsg.chpl | 2235 +++++++++++++++--------------------- src/registry/Commands.chpl | 300 +++++ 4 files changed, 1321 insertions(+), 1498 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 132328cdfb..2cbed3f396 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -585,8 +585,8 @@ def _r_binop(self, other: pdarray, op: str) -> pdarray: if dt not in DTypes: raise TypeError(f"Unhandled scalar type: {other} ({type(other)})") repMsg = generic_msg( - cmd=f"binopsv{self.ndim}D", - args={"op": op, "dtype": dt, "value": other, "a": self}, + cmd=f"binopsv<{self.dtype},{dt},{self.ndim}>", + args={"op": op, "dtype": dt, "a": self, "value": other}, ) return create_pdarray(repMsg) @@ -777,7 +777,10 @@ def opeq(self, other, op): if isinstance(other, pdarray): if self.shape != other.shape: raise ValueError(f"shape mismatch {self.shape} {other.shape}") - generic_msg(cmd=f"opeqvv{self.ndim}D", args={"op": op, "a": self, "b": other}) + generic_msg( + cmd=f"opeqvv<{self.dtype},{other.dtype},{self.ndim}>", + args={"op": op, "a": self, "b": other} + ) return self # pdarray binop scalar # opeq requires scalar to be cast as pdarray dtype @@ -791,8 +794,9 @@ def opeq(self, other, op): raise TypeError(f"Unhandled scalar type: {other} ({type(other)})") generic_msg( - cmd=f"opeqvs{self.ndim}D", - args={"op": op, "a": self, "dtype": self.dtype.name, "value": self.format_other(other)}, + # TODO: does opeqvs really need to select over pairs of dtypes? + cmd=f"opeqvs<{self.dtype},{self.dtype},{self.ndim}>", + args={"op": op, "a": self, "value": self.format_other(other)}, ) return self diff --git a/src/BinOp.chpl b/src/BinOp.chpl index a93a08ca3f..92703d1f6a 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -783,48 +783,47 @@ module BinOp } } - proc doBinOpsv(val, r, e, op: string, dtype, rname, pn, st) throws { - if e.etype == bool { + proc doBinOpsv(val, r, type etype, op: string, pn, st) throws { + var e = makeDistArray((...r.tupShape), etype); + const nie = notImplementedError(pn,"%s %s %s".format(type2str(val.type),op,type2str(r.a.eltType))); + + if etype == bool { // Since we know that the result type is a boolean, we know // that it either (1) is an operation between bools or (2) uses // a boolean operator (<, <=, etc.) if r.etype == bool && val.type == bool { select op { when "|" { - e.a = val | r.a; + e = val | r.a; } when "&" { - e.a = val & r.a; + e = val & r.a; } when "^" { - e.a = val ^ r.a; + e = val ^ r.a; } when "==" { - e.a = val == r.a; + e = val == r.a; } when "!=" { - e.a = val != r.a; + e = val != r.a; } when "<" { - e.a = val:int < r.a:int; + e = val:int < r.a:int; } when ">" { - e.a = val:int > r.a:int; + e = val:int > r.a:int; } when "<=" { - e.a = val:int <= r.a:int; + e = val:int <= r.a:int; } when ">=" { - e.a = val:int >= r.a:int; + e = val:int >= r.a:int; } when "+" { - e.a = val | r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val | r.a; } + otherwise do return MsgTuple.error(nie); } } // All types support the same binary operations when the resultant @@ -834,111 +833,102 @@ module BinOp if ((r.etype == real && val.type == bool) || (r.etype == bool && val.type == real)) { select op { when "<" { - e.a = val:real < r.a:real; + e = val:real < r.a:real; } when ">" { - e.a = val:real > r.a:real; + e = val:real > r.a:real; } when "<=" { - e.a = val:real <= r.a:real; + e = val:real <= r.a:real; } when ">=" { - e.a = val:real >= r.a:real; + e = val:real >= r.a:real; } when "==" { - e.a = val:real == r.a:real; + e = val:real == r.a:real; } when "!=" { - e.a = val:real != r.a:real; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val:real != r.a:real; } + otherwise do return MsgTuple.error(nie); } } else { select op { when "<" { - e.a = val < r.a; + e = val < r.a; } when ">" { - e.a = val > r.a; + e = val > r.a; } when "<=" { - e.a = val <= r.a; + e = val <= r.a; } when ">=" { - e.a = val >= r.a; + e = val >= r.a; } when "==" { - e.a = val == r.a; + e = val == r.a; } when "!=" { - e.a = val != r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val != r.a; } + otherwise do return MsgTuple.error(nie); } } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } // Since we know that both `l` and `r` are of type `int` and that // the resultant type is not bool (checked in first `if`), we know // what operations are supported based on the resultant type else if (r.etype == int && val.type == int) || (r.etype == uint && val.type == uint) { - if e.etype == int || e.etype == uint { + if etype == int || etype == uint { select op { when "+" { - e.a = val + r.a; + e = val + r.a; } when "-" { - e.a = val - r.a; + e = val - r.a; } when "*" { - e.a = val * r.a; + e = val * r.a; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val/ri else 0; } - when "%" { // modulo - ref ea = e.a; + when "%" { // modulo " + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val%ri else 0; } when "<<" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < 64) then ei = val << ri; } when ">>" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < 64) then ei = val >> ri; } when "<<<" { - e.a = rotl(val, r.a); + e = rotl(val, r.a); } when ">>>" { - e.a = rotr(val, r.a); + e = rotr(val, r.a); } when "&" { - e.a = val & r.a; + e = val & r.a; } when "|" { - e.a = val | r.a; + e = val | r.a; } when "^" { - e.a = val ^ r.a; + e = val ^ r.a; } when "**" { if || reduce (r.a<0){ @@ -946,65 +936,50 @@ module BinOp omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); return new MsgTuple(errorMsg, MsgType.ERROR); } - e.a= val**r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e= val**r.a; } + otherwise do return MsgTuple.error(nie); } - } else if e.etype == real { + } else if etype == real { select op { // True division is the only integer type that would result in a // resultant type of `real` when "/" { - e.a = val:real / r.a:real; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val:real / r.a:real; } - + otherwise do return MsgTuple.error(nie); } } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } - else if (e.etype == int && val.type == uint) || - (e.etype == uint && val.type == int) { + else if (etype == int && val.type == uint) || + (etype == uint && val.type == int) { select op { when ">>" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if ri:uint < 64 then ei = val:r.etype >> ri; } when "<<" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if ri:uint < 64 then ei = val:r.etype << ri; } when ">>>" { - e.a = rotr(val:r.etype, r.a); + e = rotr(val:r.etype, r.a); } when "<<<" { - e.a = rotl(val:r.etype, r.a); + e = rotl(val:r.etype, r.a); } when "+" { - e.a = val:r.etype + r.a; + e = val:r.etype + r.a; } when "-" { - e.a = val:r.etype - r.a; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val:r.etype - r.a; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } // If either RHS or LHS type is real, the same operations are supported and the // result will always be a `real`, so all 3 of these cases can be shared. @@ -1012,174 +987,145 @@ module BinOp || (r.etype == real && val.type == int)) { select op { when "+" { - e.a = val + r.a; + e = val + r.a; } when "-" { - e.a = val - r.a; + e = val - r.a; } when "*" { - e.a = val * r.a; + e = val * r.a; } when "/" { // truediv - e.a = val:real / r.a:real; + e = val:real / r.a:real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(val:real, ri); } when "**" { - e.a= val**r.a; + e= val**r.a; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = modHelper(val:real, ri); } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } - else if e.etype == real && ((r.etype == uint && val.type == int) || (r.etype == int && val.type == uint)) { + else if etype == real && ((r.etype == uint && val.type == int) || (r.etype == int && val.type == uint)) { select op { when "+" { - e.a = val:real + r.a:real; + e = val:real + r.a:real; } when "-" { - e.a = val:real - r.a:real; + e = val:real - r.a:real; } when "/" { // truediv - e.a = val:real / r.a:real; + e = val:real / r.a:real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; var ra = r.a; [(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(val:real, ri); } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((r.etype == uint && val.type == real) || (r.etype == real && val.type == uint)) { select op { when "+" { - e.a = val:real + r.a:real; + e = val:real + r.a:real; } when "-" { - e.a = val:real - r.a:real; + e = val:real - r.a:real; } when "*" { - e.a = val:real * r.a:real; + e = val:real * r.a:real; } when "/" { // truediv - e.a = val:real / r.a:real; + e = val:real / r.a:real; } when "//" { // floordiv - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = floorDivisionHelper(val:real, ri); } when "**" { - e.a= val:real**r.a:real; + e= val:real**r.a:real; } - when "%" { - ref ea = e.a; + when "%" { // " + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = modHelper(val:real, ri); } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((r.etype == int && val.type == bool) || (r.etype == bool && val.type == int)) { select op { when "+" { // Since we don't know which of `l` or `r` is the int and which is the `bool`, // we can just cast both to int, which will be a noop for the vector that is // already `int` - e.a = val:int + r.a:int; + e = val:int + r.a:int; } when "-" { - e.a = val:int - r.a:int; + e = val:int - r.a:int; } when "*" { - e.a = val:int * r.a:int; + e = val:int * r.a:int; } when ">>" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < 64) then ei = val:int >> ri:int; } when "<<" { - ref ea = e.a; + ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] if (0 <= ri && ri < 64) then ei = val:int << ri:int; } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if ((r.etype == real && val.type == bool) || (r.etype == bool && val.type == real)) { select op { when "+" { - e.a = val:real + r.a:real; + e = val:real + r.a:real; } when "-" { - e.a = val:real - r.a:real; + e = val:real - r.a:real; } when "*" { - e.a = val:real * r.a:real; - } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + e = val:real * r.a:real; } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); } else if (r.etype == bool && val.type == bool) { select op { when "<<" { if(val){ - e.a = val:int << r.a:int; + e = val:int << r.a:int; } } when ">>" { if(val){ - e.a = val:int >> r.a:int; + e = val:int >> r.a:int; } } - otherwise { - var errorMsg = notImplementedError(pn,dtype,op,r.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + otherwise do return MsgTuple.error(nie); } - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return st.insert(new shared SymEntry(e)); + } else { + const errorMsg = unrecognizedTypeError(pn, "("+type2str(val.type)+","+type2str(r.a.eltType)+")"); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return new MsgTuple(errorMsg, MsgType.ERROR); } - var errorMsg = unrecognizedTypeError(pn, "("+dtype2str(dtype)+","+dtype2str(r.dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } proc doBigIntBinOpvv(l, r, op: string) throws { diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 7b7b013fc6..036b965d7c 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -411,22 +411,21 @@ module OperatorMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - @arkouda.registerND - proc binopsvMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + @arkouda.instantiateAndRegister + proc binopsv(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, + type binop_dtype_a, + type binop_dtype_b, + param array_nd: int + ): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string = ""; // response message - const op = msgArgs.getValueOf("op"); - const aname = msgArgs.getValueOf("a"); - const value = msgArgs.get("value"); + const r = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd), + val = msgArgs['value'].toScalar(binop_dtype_b), + op = msgArgs['op'].toScalar(string); - var dtype = str2dtype(msgArgs.getValueOf("dtype")); - var rname = st.nextName(); - var right: borrowed GenSymEntry = getGenericTypedArrayEntry(aname, st); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "command = %? op = %? scalar dtype = %? scalar = %? pdarray = %?".format( - cmd,op,dtype2str(dtype),value,st.attrib(aname))); + "cmd: %? op = %? scalar dtype = %? scalar = %? pdarray = %?".format( + cmd,op,type2str(binop_dtype_b),msgArgs['value'].val,st.attrib(msgArgs['a'].val))); use Set; // This boolOps set is a filter to determine the output type for the operation. @@ -439,314 +438,148 @@ module OperatorMsg boolOps.add(">="); boolOps.add("=="); boolOps.add("!="); - + var realOps: set(string); realOps.add("+"); realOps.add("-"); realOps.add("/"); realOps.add("//"); - select (dtype, right.dtype) { - when (DType.Int64, DType.Int64) { - var val = value.getIntValue(); - var r = toSymEntry(right,int, nd); - - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } else if op == "/" { - // True division is the only case in this int, int case - // that results in a `real` symbol table entry. - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, int); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - when (DType.Int64, DType.Float64) { - var val = value.getIntValue(); - var r = toSymEntry(right,real, nd); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - when (DType.Float64, DType.Int64) { - var val = value.getRealValue(); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - when (DType.UInt64, DType.Float64) { - var val = value.getUIntValue(); - var r = toSymEntry(right,real, nd); - // Only two possible resultant types are `bool` and `real` - // for this case - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + if binop_dtype_a == int && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); + } else if op == "/" { + // True division is the only case in this int, int case + // that results in a `real` symbol table entry. + return doBinOpsv(val, r, real, op, pn, st); } - when (DType.Float64, DType.UInt64) { - var val = value.getRealValue(); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, int, op, pn, st); + } + else if binop_dtype_a == int && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Float64, DType.Float64) { - var val = value.getRealValue(); - var r = toSymEntry(right,real, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - // For cases where a boolean operand is involved, the only - // possible resultant type is `bool` - when (DType.Bool, DType.Bool) { - var val = value.getBoolValue(); - var r = toSymEntry(right,bool, nd); - if (op == "<<") || (op == ">>") { - var e = st.addEntry(rname, r.tupShape, int); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == real { + // Only two possible resultant types are `bool` and `real` + // for this case + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Bool, DType.Int64) { - var val = value.getBoolValue(); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, int); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Int64, DType.Bool) { - var val = value.getIntValue(); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, int); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Bool, DType.Float64) { - var val = value.getBoolValue(); - var r = toSymEntry(right,real, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, real, op, pn, st); + } + // For cases where a boolean operand is involved, the only + // possible resultant type is `bool` + else if binop_dtype_a == bool && binop_dtype_b == bool { + if (op == "<<") || (op == ">>") { + return doBinOpsv(val, r, int, op, pn, st); } - when (DType.Float64, DType.Bool) { - var val = value.getRealValue(); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, bool, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Bool, DType.UInt64) { - var val = value.getBoolValue(); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, uint); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, int, op, pn, st); + } + else if binop_dtype_a == int && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.UInt64, DType.Bool) { - var val = value.getUIntValue(); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - var e = st.addEntry(rname, r.tupShape, uint); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); + return doBinOpsv(val, r, int, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == real { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.UInt64, DType.UInt64) { - var val = value.getUIntValue(); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - if op == "/"{ - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } else { - var e = st.addEntry(rname, r.tupShape, uint); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == real && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.UInt64, DType.Int64) { - var val = value.getUIntValue(); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, r.tupShape, uint); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } + return doBinOpsv(val, r, real, op, pn, st); + } + else if binop_dtype_a == bool && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.Int64, DType.UInt64) { - var val = value.getIntValue(); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, bool); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } - // +, -, /, // both result in real outputs to match NumPy - if realOps.contains(op) { - var e = st.addEntry(rname, r.tupShape, real); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } else { - // isn't +, -, /, // so we can use LHS to determine type - var e = st.addEntry(rname, r.tupShape, int); - return doBinOpsv(val, r, e, op, dtype, rname, pn, st); - } + return doBinOpsv(val, r, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == bool { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.BigInt, DType.BigInt) { - var val = value.getBigIntValue(); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + return doBinOpsv(val, r, uint, op, pn, st); + } + else if binop_dtype_a == uint && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.BigInt, DType.Int64) { - var val = value.getBigIntValue(); - var r = toSymEntry(right,int, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + if op == "/"{ + return doBinOpsv(val, r, real, op, pn, st); + } else { + return doBinOpsv(val, r, uint, op, pn, st); } - when (DType.BigInt, DType.UInt64) { - var val = value.getBigIntValue(); - var r = toSymEntry(right,uint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == uint && binop_dtype_b == int { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.BigInt, DType.Bool) { - var val = value.getBigIntValue(); - var r = toSymEntry(right,bool, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpsv(val, r, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpsv(val, r, uint, op, pn, st); } - when (DType.Int64, DType.BigInt) { - var val = value.getIntValue(); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if binop_dtype_a == int && binop_dtype_b == uint { + if boolOps.contains(op) { + return doBinOpsv(val, r, bool, op, pn, st); } - when (DType.UInt64, DType.BigInt) { - var val = value.getUIntValue(); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + // +, -, /, // both result in real outputs to match NumPy + if realOps.contains(op) { + return doBinOpsv(val, r, real, op, pn, st); + } else { + // isn't +, -, /, // so we can use LHS to determine type + return doBinOpsv(val, r, int, op, pn, st); } - when (DType.Bool, DType.BigInt) { - var val = value.getBoolValue(); - var r = toSymEntry(right,bigint, nd); - if boolOps.contains(op) { - // call bigint specific func which returns distr bool array - var e = st.addEntry(rname, createSymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); - } - // call bigint specific func which returns dist bigint array - var (tmp, max_bits) = doBigIntBinOpsv(val, r, op); - var e = st.addEntry(rname, createSymEntry(tmp, max_bits)); - var repMsg = "created %s".format(st.attrib(rname)); - return new MsgTuple(repMsg, MsgType.NORMAL); + } + else if (binop_dtype_a == bigint || binop_dtype_b == bigint) && + !isRealType(binop_dtype_a) && !isRealType(binop_dtype_b) { + if boolOps.contains(op) { + return st.insert(new shared SymEntry(doBigIntBinOpsvBoolReturn(val, r, op))); } + // call bigint specific func which returns dist bigint array + const (tmp, max_bits) = doBigIntBinOpsv(val, r, op); + return st.insert(new shared SymEntry(tmp, max_bits)); + } else { + const errorMsg = unrecognizedTypeError(pn, "("+type2str(binop_dtype_a)+","+type2str(binop_dtype_b)+")"); + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + return MsgTuple.error(errorMsg); } - var errorMsg = unrecognizedTypeError(pn, "("+dtype2str(dtype)+","+dtype2str(right.dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); } /* @@ -762,587 +595,458 @@ module OperatorMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - @arkouda.registerND - proc opeqvvMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + @arkouda.instantiateAndRegister + proc opeqvv(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, + type binop_dtype_a, + type binop_dtype_b, + param array_nd: int + ): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string; // response message - const op = msgArgs.getValueOf("op"); - const aname = msgArgs.getValueOf("a"); - const bname = msgArgs.getValueOf("b"); - - // retrieve left and right pdarray objects - var left: borrowed GenSymEntry = getGenericTypedArrayEntry(aname, st); - var right: borrowed GenSymEntry = getGenericTypedArrayEntry(bname, st); + var l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd); + const r = st[msgArgs['b']]: borrowed SymEntry(binop_dtype_b, array_nd), + op = msgArgs['op'].toScalar(string), + nie = notImplementedError(pn,type2str(binop_dtype_a),op,type2str(binop_dtype_b)); omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), "cmd: %s op: %s left pdarray: %s right pdarray: %s".format(cmd,op, - st.attrib(aname),st.attrib(bname))); + st.attrib(msgArgs['a'].val),st.attrib(msgArgs['b'].val))); - select (left.dtype, right.dtype) { - when (DType.Int64, DType.Int64) { - var l = toSymEntry(left,int, nd); - var r = toSymEntry(right,int, nd); - select op { - when "+=" { l.a += r.a; } - when "-=" { l.a -= r.a; } - when "*=" { l.a *= r.a; } - when ">>=" { l.a >>= r.a;} - when "<<=" { l.a <<= r.a;} - when "//=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; - }//floordiv - when "%=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; - } - when "**=" { - if || reduce (r.a<0){ - var errorMsg = "Attempt to exponentiate base of type Int64 to negative exponent"; - return new MsgTuple(errorMsg, MsgType.ERROR); - } - else{ l.a **= r.a; } - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + if binop_dtype_a == int && binop_dtype_b == int { + select op { + when "+=" { l.a += r.a; } + when "-=" { l.a -= r.a; } + when "*=" { l.a *= r.a; } + when ">>=" { l.a >>= r.a;} + when "<<=" { l.a <<= r.a;} + when "//=" { + //l.a /= r.a; + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; + }//floordiv + when "%=" { + //l.a /= r.a; + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; } - } - when (DType.Int64, DType.UInt64) { - // The result of operations between int and uint are float by default which doesn't fit in either type - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.Int64, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.Int64, DType.Bool) { - var l = toSymEntry(left, int, nd); - var r = toSymEntry(right, bool, nd); - select op { - when "+=" {l.a += r.a:int;} - when "-=" {l.a -= r.a:int;} - when "*=" {l.a *= r.a:int;} - when ">>=" { l.a >>= r.a:int;} - when "<<=" { l.a <<= r.a:int;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + when "**=" { + if || reduce (r.a<0){ + var errorMsg = "Attempt to exponentiate base of type Int64 to negative exponent"; return new MsgTuple(errorMsg, MsgType.ERROR); } + else{ l.a **= r.a; } } + otherwise do return MsgTuple.error(nie); } - when (DType.Int64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.UInt64, DType.Int64) { - // The result of operations between int and uint are float by default which doesn't fit in either type - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == int && binop_dtype_b == bool { + select op { + when "+=" {l.a += r.a:int;} + when "-=" {l.a -= r.a:int;} + when "*=" {l.a *= r.a:int;} + when ">>=" { l.a >>= r.a:int;} + when "<<=" { l.a <<= r.a:int;} + otherwise do return MsgTuple.error(nie); } - when (DType.UInt64, DType.UInt64) { - var l = toSymEntry(left,uint, nd); - var r = toSymEntry(right,uint, nd); - select op { - when "+=" { l.a += r.a; } - when "-=" { - l.a -= r.a; - } - when "*=" { l.a *= r.a; } - when "//=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; - }//floordiv - when "%=" { - //l.a /= r.a; - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; - } - when "**=" { - l.a **= r.a; - } - when ">>=" { l.a >>= r.a;} - when "<<=" { l.a <<= r.a;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == uint && binop_dtype_b == uint { + select op { + when "+=" { l.a += r.a; } + when "-=" { + l.a -= r.a; } - } - when (DType.UInt64, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.UInt64, DType.Bool) { - var l = toSymEntry(left, uint, nd); - var r = toSymEntry(right, bool, nd); - select op { - when "+=" {l.a += r.a:uint;} - when "-=" {l.a -= r.a:uint;} - when "*=" {l.a *= r.a:uint;} - when ">>=" { l.a >>= r.a:uint;} - when "<<=" { l.a <<= r.a:uint;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + when "*=" { l.a *= r.a; } + when "//=" { + //l.a /= r.a; + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = if ri != 0 then li/ri else 0; + }//floordiv + when "%=" { + //l.a /= r.a; + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = if ri != 0 then li%ri else 0; + } + when "**=" { + l.a **= r.a; } + when ">>=" { l.a >>= r.a;} + when "<<=" { l.a <<= r.a;} + otherwise do return MsgTuple.error(nie); } - when (DType.UInt64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == uint && binop_dtype_b == bool { + select op { + when "+=" {l.a += r.a:uint;} + when "-=" {l.a -= r.a:uint;} + when "*=" {l.a *= r.a:uint;} + when ">>=" { l.a >>= r.a:uint;} + when "<<=" { l.a <<= r.a:uint;} + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Int64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,int, nd); - - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a:real;} //truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == int { + select op { + when "+=" {l.a += r.a;} + when "-=" {l.a -= r.a;} + when "*=" {l.a *= r.a;} + when "/=" {l.a /= r.a:real;} //truediv + when "//=" { //floordiv + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); } - } - when (DType.Float64, DType.UInt64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,uint, nd); - - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a:real;} //truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + when "**=" { l.a **= r.a; } + when "%=" { + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = modHelper(li, ri); } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Float64) { - var l = toSymEntry(left,real, nd); - var r = toSymEntry(right,real, nd); - select op { - when "+=" {l.a += r.a;} - when "-=" {l.a -= r.a;} - when "*=" {l.a *= r.a;} - when "/=" {l.a /= r.a;}//truediv - when "//=" { //floordiv - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); - } - when "**=" { l.a **= r.a; } - when "%=" { - ref la = l.a; - ref ra = r.a; - [(li,ri) in zip(la,ra)] li = modHelper(li, ri); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == uint { + select op { + when "+=" {l.a += r.a;} + when "-=" {l.a -= r.a;} + when "*=" {l.a *= r.a;} + when "/=" {l.a /= r.a:real;} //truediv + when "//=" { //floordiv + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); } + when "**=" { l.a **= r.a; } + when "%=" { + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = modHelper(li, ri); + } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Bool) { - var l = toSymEntry(left, real, nd); - var r = toSymEntry(right, bool, nd); - select op { - when "+=" {l.a += r.a:real;} - when "-=" {l.a -= r.a:real;} - when "*=" {l.a *= r.a:real;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == real { + select op { + when "+=" {l.a += r.a;} + when "-=" {l.a -= r.a;} + when "*=" {l.a *= r.a;} + when "/=" {l.a /= r.a;}//truediv + when "//=" { //floordiv + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = floorDivisionHelper(li, ri); + } + when "**=" { l.a **= r.a; } + when "%=" { + ref la = l.a; + ref ra = r.a; + [(li,ri) in zip(la,ra)] li = modHelper(li, ri); } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == real && binop_dtype_b == bool { + select op { + when "+=" {l.a += r.a:real;} + when "-=" {l.a -= r.a:real;} + when "*=" {l.a *= r.a:real;} + otherwise do return MsgTuple.error(nie); } - when (DType.Bool, DType.Bool) { - var l = toSymEntry(left, bool, nd); - var r = toSymEntry(right, bool, nd); - select op { - when "|=" {l.a |= r.a;} - when "&=" {l.a &= r.a;} - when "^=" {l.a ^= r.a;} - when "+=" {l.a |= r.a;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + } + else if binop_dtype_a == bool && binop_dtype_b == bool { + select op { + when "|=" {l.a |= r.a;} + when "&=" {l.a &= r.a;} + when "^=" {l.a ^= r.a;} + when "+=" {l.a |= r.a;} + otherwise do return MsgTuple.error(nie); } - when (DType.BigInt, DType.Int64) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,int, nd); - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } + } + else if binop_dtype_a == bigint && binop_dtype_b == int { + ref la = l.a; + ref ra = r.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li += ri; + if has_max_bits { + li &= local_max_size; } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li -= ri; + if has_max_bits { + li &= local_max_size; } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li *= ri; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "//=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + li /= ri; } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + else { + li = 0:bigint; } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } + if has_max_bits { + li &= local_max_size; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + when "%=" { + // we can't use li %= ri because this can result in negatives + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + mod(li, li, ri); + } + else { + li = 0:bigint; + } + if has_max_bits { + li &= local_max_size; } } - } - when (DType.BigInt, DType.UInt64) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,uint, nd); - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; + } + when "**=" { + if || reduce (ra<0) { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + powMod(li, li, ri, local_max_size + 1); } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + else { + forall (li, ri) in zip(la, ra) { + li **= ri:uint; } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + otherwise do return MsgTuple.error(nie); + } + } + else if binop_dtype_a == bigint && binop_dtype_b == uint { + ref la = l.a; + ref ra = r.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li += ri; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li -= ri; + if has_max_bits { + li &= local_max_size; } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li *= ri; + if has_max_bits { + li &= local_max_size; } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } + } + } + when "//=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + li /= ri; + } + else { + li = 0:bigint; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + if has_max_bits { + li &= local_max_size; } } - } - when (DType.BigInt, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.BigInt, DType.Bool) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,bool, nd); - ref la = l.a; - var ra = r.a:bigint; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; + } + when "%=" { + // we can't use li %= ri because this can result in negatives + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + mod(li, li, ri); + } + else { + li = 0:bigint; + } + if has_max_bits { + li &= local_max_size; + } + } + } + when "**=" { + if || reduce (ra<0) { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + powMod(li, li, ri, local_max_size + 1); } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + else { + forall (li, ri) in zip(la, ra) { + li **= ri:uint; } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + otherwise do return MsgTuple.error(nie); + } + } + else if binop_dtype_a == bigint && binop_dtype_b == bool { + ref la = l.a; + var ra = r.a:bigint; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li += ri; + if has_max_bits { + li &= local_max_size; + } + } + } + when "-=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li -= ri; + if has_max_bits { + li &= local_max_size; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + when "*=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li *= ri; + if has_max_bits { + li &= local_max_size; } } + } + otherwise do return MsgTuple.error(nie); } - when (DType.BigInt, DType.BigInt) { - var l = toSymEntry(left,bigint, nd); - var r = toSymEntry(right,bigint, nd); - ref la = l.a; - ref ra = r.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li += ri; - if has_max_bits { - li &= local_max_size; - } - } + } + else if binop_dtype_a == bigint && binop_dtype_b == bigint { + ref la = l.a; + ref ra = r.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li += ri; + if has_max_bits { + li &= local_max_size; } - when "-=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li -= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li -= ri; + if has_max_bits { + li &= local_max_size; } - when "*=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - li *= ri; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + li *= ri; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - li /= ri; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "//=" { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + li /= ri; } - when "%=" { - // we can't use li %= ri because this can result in negatives - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - if ri != 0 { - mod(li, li, ri); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + else { + li = 0:bigint; } - when "**=" { - if || reduce (ra<0) { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { - powMod(li, li, ri, local_max_size + 1); - } - } - else { - forall (li, ri) in zip(la, ra) { - li **= ri:uint; - } - } + if has_max_bits { + li &= local_max_size; + } + } + } + when "%=" { + // we can't use li %= ri because this can result in negatives + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + if ri != 0 { + mod(li, li, ri); + } + else { + li = 0:bigint; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,right.dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + if has_max_bits { + li &= local_max_size; } } + } + when "**=" { + if || reduce (ra<0) { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } + if has_max_bits { + forall (li, ri) in zip(la, ra) with (var local_max_size = max_size) { + powMod(li, li, ri, local_max_size + 1); + } + } + else { + forall (li, ri) in zip(la, ra) { + li **= ri:uint; + } + } + } + otherwise do return MsgTuple.error(nie); } - otherwise { - var errorMsg = unrecognizedTypeError(pn, - "("+dtype2str(left.dtype)+","+dtype2str(right.dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } - repMsg = "opeqvv success"; - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); + } else { + return MsgTuple.error(nie); + } + + return MsgTuple.success(); } /* @@ -1358,581 +1062,450 @@ module OperatorMsg :returns: (MsgTuple) :throws: `UndefinedSymbolError(name)` */ - @arkouda.registerND - proc opeqvsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, param nd: int): MsgTuple throws { + @arkouda.instantiateAndRegister + proc opeqvs(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, + type binop_dtype_a, + type binop_dtype_b, + param array_nd: int + ): MsgTuple throws { param pn = Reflection.getRoutineName(); - var repMsg: string; // response message - - const op = msgArgs.getValueOf("op"); - const aname = msgArgs.getValueOf("a"); - const value = msgArgs.get("value"); - var dtype = str2dtype(msgArgs.getValueOf("dtype")); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "cmd: %s op: %s aname: %s dtype: %s scalar: %s".format( - cmd,op,aname,dtype2str(dtype),value.getValue())); + var l = st[msgArgs['a']]: borrowed SymEntry(binop_dtype_a, array_nd); + const val = msgArgs['value'].toScalar(binop_dtype_b), + op = msgArgs['op'].toScalar(string), + nie = notImplementedError(pn,type2str(binop_dtype_a),op,type2str(binop_dtype_b)); - var left: borrowed GenSymEntry = getGenericTypedArrayEntry(aname, st); - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), - "op: %? pdarray: %? scalar: %?".format(op,st.attrib(aname),value.getValue())); - select (left.dtype, dtype) { - when (DType.Int64, DType.Int64) { - var l = toSymEntry(left,int, nd); - var val = value.getIntValue(); - select op { - when "+=" { l.a += val; } - when "-=" { l.a -= val; } - when "*=" { l.a *= val; } - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - when "//=" { - if val != 0 {l.a /= val;} else {l.a = 0;} - }//floordiv - when "%=" { - if val != 0 {l.a %= val;} else {l.a = 0;} - } - when "**=" { - if val<0 { - var errorMsg = "Attempt to exponentiate base of type Int64 to negative exponent"; - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(), - errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - else{ l.a **= val; } + "op: %? pdarray: %? scalar: %?".format(op,st.attrib(msgArgs['a'].val),val)); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + if binop_dtype_a == int && binop_dtype_b == int { + select op { + when "+=" { l.a += val; } + when "-=" { l.a -= val; } + when "*=" { l.a *= val; } + when ">>=" { l.a >>= val; } + when "<<=" { l.a <<= val; } + when "//=" { + if val != 0 {l.a /= val;} else {l.a = 0;} + }//floordiv + when "%=" { + if val != 0 {l.a %= val;} else {l.a = 0;} } - } - when (DType.Int64, DType.UInt64) { - var l = toSymEntry(left,int, nd); - var val = value.getUIntValue(); - select op { - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - otherwise { - // The result of operations between int and uint are float by default which doesn't fit in either type - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); + when "**=" { + if val<0 { + var errorMsg = "Attempt to exponentiate base of type int64 to negative exponent"; + omLogger.error(getModuleName(),getRoutineName(),getLineNumber(), + errorMsg); return new MsgTuple(errorMsg, MsgType.ERROR); } + else{ l.a **= val; } + } + otherwise do return MsgTuple.error(nie); } - when (DType.Int64, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == int && binop_dtype_b == uint { + select op { + when ">>=" { l.a >>= val; } + when "<<=" { l.a <<= val; } + otherwise do return MsgTuple.error(nie); } - when (DType.Int64, DType.Bool) { - var l = toSymEntry(left, int, nd); - var val = value.getBoolValue(); - select op { - when "+=" {l.a += val:int;} - when "-=" {l.a -= val:int;} - when "*=" {l.a *= val:int;} - when ">>=" {l.a >>= val:int; } - when "<<=" {l.a <<= val:int; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + } + else if binop_dtype_a == int && binop_dtype_b == bool { + select op { + when "+=" {l.a += val:int;} + when "-=" {l.a -= val:int;} + when "*=" {l.a *= val:int;} + when ">>=" {l.a >>= val:int; } + when "<<=" {l.a <<= val:int; } + otherwise do return MsgTuple.error(nie); } - when (DType.Int64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == uint && binop_dtype_b == int { + select op { + when ">>=" { l.a >>= val; } + when "<<=" { l.a <<= val; } + otherwise do return MsgTuple.error(nie); } - when (DType.UInt64, DType.Int64) { - var l = toSymEntry(left, uint, nd); - var val = value.getIntValue(); - select op { - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - otherwise { - // The result of operations between int and uint are float by default which doesn't fit in either type - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == uint && binop_dtype_b == uint { + select op { + when "+=" { l.a += val; } + when "-=" { + l.a -= val; } - } - when (DType.UInt64, DType.UInt64) { - var l = toSymEntry(left,uint, nd); - var val = value.getUIntValue(); - select op { - when "+=" { l.a += val; } - when "-=" { - l.a -= val; - } - when "*=" { l.a *= val; } - when "//=" { - if val != 0 {l.a /= val;} else {l.a = 0;} - }//floordiv - when "%=" { - if val != 0 {l.a %= val;} else {l.a = 0;} - } - when "**=" { - l.a **= val; - } - when ">>=" { l.a >>= val; } - when "<<=" { l.a <<= val; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + when "*=" { l.a *= val; } + when "//=" { + if val != 0 {l.a /= val;} else {l.a = 0;} + }//floordiv + when "%=" { + if val != 0 {l.a %= val;} else {l.a = 0;} } - } - when (DType.UInt64, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.UInt64, DType.Bool) { - var l = toSymEntry(left, uint, nd); - var val = value.getBoolValue(); - select op { - when "+=" {l.a += val:uint;} - when "-=" {l.a -= val:uint;} - when "*=" {l.a *= val:uint;} - when ">>=" { l.a >>= val:uint;} - when "<<=" { l.a <<= val:uint;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + when "**=" { + l.a **= val; } + when ">>=" { l.a >>= val; } + when "<<=" { l.a <<= val; } + otherwise do return MsgTuple.error(nie); } - when (DType.Bool, DType.Bool) { - var l = toSymEntry(left, bool, nd); - var val = value.getBoolValue(); - select op { - when "+=" {l.a |= val;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - } + } + else if binop_dtype_a == uint && binop_dtype_b == bool { + select op { + when "+=" {l.a += val:uint;} + when "-=" {l.a -= val:uint;} + when "*=" {l.a *= val:uint;} + when ">>=" { l.a >>= val:uint;} + when "<<=" { l.a <<= val:uint;} + otherwise do return MsgTuple.error(nie); } - when (DType.UInt64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == bool && binop_dtype_b == bool { + select op { + when "+=" {l.a |= val;} + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Int64) { - var l = toSymEntry(left,real, nd); - var val = value.getIntValue(); - select op { - when "+=" {l.a += val;} - when "-=" {l.a -= val;} - when "*=" {l.a *= val;} - when "/=" {l.a /= val:real;} //truediv - when "//=" { //floordiv - ref la = l.a; - [li in la] li = floorDivisionHelper(li, val); - } - when "**=" { l.a **= val; } - when "%=" { - ref la = l.a; - [li in la] li = modHelper(li, val); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == int { + select op { + when "+=" {l.a += val;} + when "-=" {l.a -= val;} + when "*=" {l.a *= val;} + when "/=" {l.a /= val:real;} //truediv + when "//=" { //floordiv + ref la = l.a; + [li in la] li = floorDivisionHelper(li, val); } - } - when (DType.Float64, DType.UInt64) { - var l = toSymEntry(left,real, nd); - var val = value.getUIntValue(); - select op { - when "+=" { l.a += val; } - when "-=" { l.a -= val; } - when "*=" { l.a *= val; } - when "//=" { - ref la = l.a; - [li in la] li = floorDivisionHelper(li, val); - }//floordiv - when "**=" { - l.a **= val; - } - when "%=" { - ref la = l.a; - [li in la] li = modHelper(li, val); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + when "**=" { l.a **= val; } + when "%=" { + ref la = l.a; + [li in la] li = modHelper(li, val); } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Float64) { - var l = toSymEntry(left,real, nd); - var val = value.getRealValue(); - select op { - when "+=" {l.a += val;} - when "-=" {l.a -= val;} - when "*=" {l.a *= val;} - when "/=" {l.a /= val;}//truediv - when "//=" { //floordiv - ref la = l.a; - [li in la] li = floorDivisionHelper(li, val); - } - when "**=" { l.a **= val; } - when "%=" { - ref la = l.a; - [li in la] li = modHelper(li, val); - } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == uint { + select op { + when "+=" { l.a += val; } + when "-=" { l.a -= val; } + when "*=" { l.a *= val; } + when "//=" { + ref la = l.a; + [li in la] li = floorDivisionHelper(li, val); + }//floordiv + when "**=" { + l.a **= val; + } + when "%=" { + ref la = l.a; + [li in la] li = modHelper(li, val); } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.Bool) { - var l = toSymEntry(left, real, nd); - var val = value.getBoolValue(); - select op { - when "+=" {l.a += val:real;} - when "-=" {l.a -= val:real;} - when "*=" {l.a *= val:real;} - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } + } + else if binop_dtype_a == real && binop_dtype_b == real { + select op { + when "+=" {l.a += val;} + when "-=" {l.a -= val;} + when "*=" {l.a *= val;} + when "/=" {l.a /= val;}//truediv + when "//=" { //floordiv + ref la = l.a; + [li in la] li = floorDivisionHelper(li, val); + } + when "**=" { l.a **= val; } + when "%=" { + ref la = l.a; + [li in la] li = modHelper(li, val); } + otherwise do return MsgTuple.error(nie); } - when (DType.Float64, DType.BigInt) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + else if binop_dtype_a == real && binop_dtype_b == bool { + select op { + when "+=" {l.a += val:real;} + when "-=" {l.a -= val:real;} + when "*=" {l.a *= val:real;} + otherwise do return MsgTuple.error(nie); } - when (DType.BigInt, DType.Int64) { - var l = toSymEntry(left,bigint, nd); - var val = value.getIntValue(); - ref la = l.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li += local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + else if binop_dtype_a == bigint && binop_dtype_b == int { + ref la = l.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li += local_val; + if has_max_bits { + li &= local_max_size; } - when "-=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li -= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li -= local_val; + if has_max_bits { + li &= local_max_size; } - when "*=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li *= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li *= local_val; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - li /= local_val; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "//=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + li /= local_val; } - when "%=" { - // we can't use li %= val because this can result in negatives - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - mod(li, li, local_val); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + else { + li = 0:bigint; } - when "**=" { - if val<0 { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall li in la with (var local_val = val, var local_max_size = max_size) { - powMod(li, li, local_val, local_max_size + 1); - } - } - else { - forall li in la with (var local_val = val) { - li **= local_val:uint; - } - } + if has_max_bits { + li &= local_max_size; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + when "%=" { + // we can't use li %= val because this can result in negatives + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + mod(li, li, local_val); + } + else { + li = 0:bigint; + } + if has_max_bits { + li &= local_max_size; } } - } - when (DType.BigInt, DType.UInt64) { - var l = toSymEntry(left,bigint, nd); - var val = value.getUIntValue(); - ref la = l.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; + } + when "**=" { + if val<0 { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li += local_val; - if has_max_bits { - li &= local_max_size; - } - } + forall li in la with (var local_val = val, var local_max_size = max_size) { + powMod(li, li, local_val, local_max_size + 1); } - when "-=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li -= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + else { + forall li in la with (var local_val = val) { + li **= local_val:uint; } - when "*=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li *= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + otherwise do return MsgTuple.error(nie); + } + } + else if binop_dtype_a == bigint && binop_dtype_b == uint { + ref la = l.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li += local_val; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - li /= local_val; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li -= local_val; + if has_max_bits { + li &= local_max_size; } - when "%=" { - // we can't use li %= val because this can result in negatives - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - mod(li, li, local_val); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li *= local_val; + if has_max_bits { + li &= local_max_size; } - when "**=" { - if val<0 { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall li in la with (var local_val = val, var local_max_size = max_size) { - powMod(li, li, local_val, local_max_size + 1); - } - } - else { - forall li in la with (var local_val = val) { - li **= local_val:uint; - } - } + } + } + when "//=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + li /= local_val; + } + else { + li = 0:bigint; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + if has_max_bits { + li &= local_max_size; } } - } - when (DType.BigInt, DType.Float64) { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } - when (DType.BigInt, DType.Bool) { - var l = toSymEntry(left, bigint, nd); - var val = value.getBoolValue(); - ref la = l.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; + } + when "%=" { + // we can't use li %= val because this can result in negatives + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + mod(li, li, local_val); + } + else { + li = 0:bigint; + } + if has_max_bits { + li &= local_max_size; + } + } + } + when "**=" { + if val<0 { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - // TODO change once we can cast directly from bool to bigint - when "+=" { - forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { - li += local_val; - if has_max_bits { - li &= local_max_size; - } - } + forall li in la with (var local_val = val, var local_max_size = max_size) { + powMod(li, li, local_val, local_max_size + 1); } - when "-=" { - forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { - li -= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + else { + forall li in la with (var local_val = val) { + li **= local_val:uint; } - when "*=" { - forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { - li *= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + otherwise do return MsgTuple.error(nie); + } + } + else if binop_dtype_a == bigint && binop_dtype_b == bool { + ref la = l.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + // TODO change once we can cast directly from bool to bigint + when "+=" { + forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { + li += local_val; + if has_max_bits { + li &= local_max_size; + } + } + } + when "-=" { + forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { + li -= local_val; + if has_max_bits { + li &= local_max_size; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + when "*=" { + forall li in la with (var local_val = val:int:bigint, var local_max_size = max_size) { + li *= local_val; + if has_max_bits { + li &= local_max_size; } } + } + otherwise do return MsgTuple.error(nie); } - when (DType.BigInt, DType.BigInt) { - var l = toSymEntry(left,bigint, nd); - var val = value.getBigIntValue(); - ref la = l.a; - var max_bits = l.max_bits; - var max_size = 1:bigint; - var has_max_bits = max_bits != -1; - if has_max_bits { - max_size <<= max_bits; - max_size -= 1; - } - select op { - when "+=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li += local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + else if binop_dtype_a == bigint && binop_dtype_b == bigint { + ref la = l.a; + var max_bits = l.max_bits; + var max_size = 1:bigint; + var has_max_bits = max_bits != -1; + if has_max_bits { + max_size <<= max_bits; + max_size -= 1; + } + select op { + when "+=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li += local_val; + if has_max_bits { + li &= local_max_size; } - when "-=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li -= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "-=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li -= local_val; + if has_max_bits { + li &= local_max_size; } - when "*=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - li *= local_val; - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "*=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + li *= local_val; + if has_max_bits { + li &= local_max_size; } - when "//=" { - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - li /= local_val; - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + } + } + when "//=" { + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + li /= local_val; } - when "%=" { - // we can't use li %= val because this can result in negatives - forall li in la with (var local_val = val, var local_max_size = max_size) { - if local_val != 0 { - mod(li, li, local_val); - } - else { - li = 0:bigint; - } - if has_max_bits { - li &= local_max_size; - } - } + else { + li = 0:bigint; } - when "**=" { - if val<0 { - throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); - } - if has_max_bits { - forall li in la with (var local_val = val, var local_max_size = max_size) { - powMod(li, li, local_val, local_max_size + 1); - } - } - else { - forall li in la with (var local_val = val) { - li **= local_val:uint; - } - } + if has_max_bits { + li &= local_max_size; } - otherwise { - var errorMsg = notImplementedError(pn,left.dtype,op,dtype); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); + } + } + when "%=" { + // we can't use li %= val because this can result in negatives + forall li in la with (var local_val = val, var local_max_size = max_size) { + if local_val != 0 { + mod(li, li, local_val); + } + else { + li = 0:bigint; + } + if has_max_bits { + li &= local_max_size; } } + } + when "**=" { + if val<0 { + throw new Error("Attempt to exponentiate base of type BigInt to negative exponent"); + } + if has_max_bits { + forall li in la with (var local_val = val, var local_max_size = max_size) { + powMod(li, li, local_val, local_max_size + 1); + } + } + else { + forall li in la with (var local_val = val) { + li **= local_val:uint; + } + } + } + otherwise do return MsgTuple.error(nie); } - otherwise { - var errorMsg = unrecognizedTypeError(pn, - "("+dtype2str(left.dtype)+","+dtype2str(dtype)+")"); - omLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); - return new MsgTuple(errorMsg, MsgType.ERROR); - } } - repMsg = "opeqvs success"; - omLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg); - return new MsgTuple(repMsg, MsgType.NORMAL); + else { + return MsgTuple.error(nie); + } + return MsgTuple.success(); } } diff --git a/src/registry/Commands.chpl b/src/registry/Commands.chpl index 8c01a3b726..f3842289b4 100644 --- a/src/registry/Commands.chpl +++ b/src/registry/Commands.chpl @@ -1787,6 +1787,306 @@ proc ark_binopvs_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: return OperatorMsg.binopvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); registerFunction('binopvs', ark_binopvs_bigint_bigint_1, 'OperatorMsg', 228); +proc ark_binopsv_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('binopsv', ark_binopsv_int_int_1, 'OperatorMsg', 415); + +proc ark_binopsv_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('binopsv', ark_binopsv_int_uint_1, 'OperatorMsg', 415); + +proc ark_binopsv_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('binopsv', ark_binopsv_int_real_1, 'OperatorMsg', 415); + +proc ark_binopsv_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('binopsv', ark_binopsv_int_bool_1, 'OperatorMsg', 415); + +proc ark_binopsv_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopsv', ark_binopsv_int_bigint_1, 'OperatorMsg', 415); + +proc ark_binopsv_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('binopsv', ark_binopsv_uint_int_1, 'OperatorMsg', 415); + +proc ark_binopsv_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopsv', ark_binopsv_uint_uint_1, 'OperatorMsg', 415); + +proc ark_binopsv_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('binopsv', ark_binopsv_uint_real_1, 'OperatorMsg', 415); + +proc ark_binopsv_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopsv', ark_binopsv_uint_bool_1, 'OperatorMsg', 415); + +proc ark_binopsv_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopsv', ark_binopsv_uint_bigint_1, 'OperatorMsg', 415); + +proc ark_binopsv_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('binopsv', ark_binopsv_real_int_1, 'OperatorMsg', 415); + +proc ark_binopsv_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('binopsv', ark_binopsv_real_uint_1, 'OperatorMsg', 415); + +proc ark_binopsv_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('binopsv', ark_binopsv_real_real_1, 'OperatorMsg', 415); + +proc ark_binopsv_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('binopsv', ark_binopsv_real_bool_1, 'OperatorMsg', 415); + +proc ark_binopsv_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopsv', ark_binopsv_real_bigint_1, 'OperatorMsg', 415); + +proc ark_binopsv_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('binopsv', ark_binopsv_bool_int_1, 'OperatorMsg', 415); + +proc ark_binopsv_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('binopsv', ark_binopsv_bool_uint_1, 'OperatorMsg', 415); + +proc ark_binopsv_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('binopsv', ark_binopsv_bool_real_1, 'OperatorMsg', 415); + +proc ark_binopsv_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('binopsv', ark_binopsv_bool_bool_1, 'OperatorMsg', 415); + +proc ark_binopsv_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopsv', ark_binopsv_bool_bigint_1, 'OperatorMsg', 415); + +proc ark_binopsv_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('binopsv', ark_binopsv_bigint_int_1, 'OperatorMsg', 415); + +proc ark_binopsv_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('binopsv', ark_binopsv_bigint_uint_1, 'OperatorMsg', 415); + +proc ark_binopsv_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('binopsv', ark_binopsv_bigint_real_1, 'OperatorMsg', 415); + +proc ark_binopsv_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('binopsv', ark_binopsv_bigint_bool_1, 'OperatorMsg', 415); + +proc ark_binopsv_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.binopsv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('binopsv', ark_binopsv_bigint_bigint_1, 'OperatorMsg', 415); + +proc ark_opeqvv_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_int_int_1, 'OperatorMsg', 599); + +proc ark_opeqvv_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_int_uint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_int_real_1, 'OperatorMsg', 599); + +proc ark_opeqvv_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_int_bool_1, 'OperatorMsg', 599); + +proc ark_opeqvv_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_int_bigint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_uint_int_1, 'OperatorMsg', 599); + +proc ark_opeqvv_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_uint_uint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_uint_real_1, 'OperatorMsg', 599); + +proc ark_opeqvv_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_uint_bool_1, 'OperatorMsg', 599); + +proc ark_opeqvv_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_uint_bigint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_real_int_1, 'OperatorMsg', 599); + +proc ark_opeqvv_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_real_uint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_real_real_1, 'OperatorMsg', 599); + +proc ark_opeqvv_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_real_bool_1, 'OperatorMsg', 599); + +proc ark_opeqvv_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_real_bigint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bool_int_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bool_uint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bool_real_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bool_bool_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bool_bigint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bigint_int_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bigint_uint_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bigint_real_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bigint_bool_1, 'OperatorMsg', 599); + +proc ark_opeqvv_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvv(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvv', ark_opeqvv_bigint_bigint_1, 'OperatorMsg', 599); + +proc ark_opeqvs_int_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_int_int_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_int_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_int_uint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_int_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_int_real_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_int_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_int_bool_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_int_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=int, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_int_bigint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_uint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_uint_int_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_uint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_uint_uint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_uint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_uint_real_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_uint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_uint_bool_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_uint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=uint, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_uint_bigint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_real_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_real_int_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_real_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_real_uint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_real_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_real_real_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_real_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_real_bool_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_real_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=real, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_real_bigint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bool_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bool_int_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bool_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bool_uint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bool_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bool_real_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bool_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bool_bool_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bool_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bool, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bool_bigint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bigint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=int, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bigint_int_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bigint_uint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=uint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bigint_uint_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bigint_real_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=real, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bigint_real_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bigint_bool_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bool, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bigint_bool_1, 'OperatorMsg', 1066); + +proc ark_opeqvs_bigint_bigint_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do + return OperatorMsg.opeqvs(cmd, msgArgs, st, binop_dtype_a=bigint, binop_dtype_b=bigint, array_nd=1); +registerFunction('opeqvs', ark_opeqvs_bigint_bigint_1, 'OperatorMsg', 1066); + import RandMsg; proc ark_randint_int_1(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws do From 2e2baaaa21696614ba196feaf1afc981fcb3d507 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Wed, 18 Sep 2024 15:04:10 -0600 Subject: [PATCH 06/10] remove unused 'rname' variables from OperatorMsg Signed-off-by: Jeremiah Corrado --- src/BinOp.chpl | 3 --- src/OperatorMsg.chpl | 8 ++------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 92703d1f6a..910de558f5 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -62,9 +62,6 @@ module BinOp :arg op: string representation of binary operation to execute :type op: string - :arg rname: name of the `e` in the symbol table - :type rname: string - :arg pn: routine name of callsite function :type pn: string diff --git a/src/OperatorMsg.chpl b/src/OperatorMsg.chpl index 036b965d7c..1ced1ac433 100644 --- a/src/OperatorMsg.chpl +++ b/src/OperatorMsg.chpl @@ -27,7 +27,7 @@ module OperatorMsg Parse and respond to binopvv message. vv == vector op vector - :arg reqMsg: request containing (cmd,op,aname,bname,rname) + :arg reqMsg: request containing (cmd,op,aname,bname) :type reqMsg: string :arg st: SymTab to act on @@ -71,8 +71,6 @@ module OperatorMsg realOps.add("/"); realOps.add("//"); - const rname = st.nextName(); - if binop_dtype_a == int && binop_dtype_b == int { if boolOps.contains(op) { return doBinOpvv(l, r, bool, op, pn, st); @@ -240,8 +238,6 @@ module OperatorMsg "cmd: %? op: %? left pdarray: %? scalar: %?".format( cmd,op,st.attrib(msgArgs['a'].val), val)); - const rname = st.nextName(); - use Set; // This boolOps set is a filter to determine the output type for the operation. // All operations that involve one of these operations result in a `bool` symbol @@ -1053,7 +1049,7 @@ module OperatorMsg Parse and respond to opeqvs message. vector op= scalar - :arg reqMsg: request containing (cmd,op,aname,bname,rname) + :arg reqMsg: request containing (cmd,op,aname,bname) :type reqMsg: string :arg st: SymTab to act on From 0488a07740ce6a44804269d32527a4ed8f9073e6 Mon Sep 17 00:00:00 2001 From: Jeremiah Corrado Date: Fri, 20 Sep 2024 09:54:32 -0600 Subject: [PATCH 07/10] add notes about syntax-highlighter workaround to BinOP Signed-off-by: Jeremiah Corrado --- src/BinOp.chpl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 910de558f5..001c0d8283 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -190,7 +190,7 @@ module BinOp ref ra = r.a; [(ei,li,ri) in zip(ea,la,ra)] ei = if ri != 0 then li/ri else 0; } - when "%" { // " modulo + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; ref ra = r.a; @@ -320,7 +320,7 @@ module BinOp when "**" { e= l.a**r.a; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; ref ra = r.a; @@ -352,7 +352,7 @@ module BinOp when "**" { e= l.a:real**r.a:real; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; ref ra = r.a; @@ -552,7 +552,7 @@ module BinOp ref la = l.a; [(ei,li) in zip(ea,la)] ei = if val != 0 then li/val else 0; } - when "%" { // modulo " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = if val != 0 then li%val else 0; @@ -673,7 +673,7 @@ module BinOp when "**" { e= l.a**val; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = modHelper(li, val); @@ -724,7 +724,7 @@ module BinOp when "**" { e= l.a: real**val: real; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref la = l.a; [(ei,li) in zip(ea,la)] ei = modHelper(li, val); @@ -897,7 +897,7 @@ module BinOp ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val/ri else 0; } - when "%" { // modulo " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = if ri != 0 then val%ri else 0; @@ -1003,7 +1003,7 @@ module BinOp when "**" { e= val**r.a; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = modHelper(val:real, ri); @@ -1054,7 +1054,7 @@ module BinOp when "**" { e= val:real**r.a:real; } - when "%" { // " + when "%" { // modulo " <- quote is workaround for syntax highlighter bug ref ea = e; ref ra = r.a; [(ei,ri) in zip(ea,ra)] ei = modHelper(val:real, ri); @@ -1271,7 +1271,7 @@ module BinOp } visted = true; } - when "%" { // modulo + when "%" { // modulo " <- quote is workaround for syntax highlighter bug // we only do in place mod when ri != 0, tmp will be 0 in other locations // we can't use ei = li % ri because this can result in negatives forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { @@ -1515,7 +1515,7 @@ module BinOp } visted = true; } - when "%" { // modulo + when "%" { // modulo " <- quote is workaround for syntax highlighter bug // we only do in place mod when val != 0, tmp will be 0 in other locations // we can't use ei = li % val because this can result in negatives forall t in tmp with (var local_val = val, var local_max_size = max_size) { @@ -1779,7 +1779,7 @@ module BinOp } visted = true; } - when "%" { // modulo + when "%" { // modulo " <- quote is workaround for syntax highlighter bug forall (t, ri) in zip(tmp, ra) with (var local_max_size = max_size) { if ri != 0 { mod(t, t, ri); From ea312d38a407a5f1e9d2e7264c5c94e712890da6 Mon Sep 17 00:00:00 2001 From: jeremiah-corrado <62707311+jeremiah-corrado@users.noreply.github.com> Date: Fri, 27 Sep 2024 09:31:15 -0600 Subject: [PATCH 08/10] Update src/BinOp.chpl Co-authored-by: tess <48131946+stress-tess@users.noreply.github.com> --- src/BinOp.chpl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 001c0d8283..4a65e22ecb 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -269,7 +269,6 @@ module BinOp } else if (l.etype == uint && r.etype == int) || (l.etype == int && r.etype == uint) { - writeln("correct dispatch... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); select op { when "+" { e = l.a:real + r.a:real; From 1e46acba4ed0c4d17f938438a547426905d6e59b Mon Sep 17 00:00:00 2001 From: jeremiah-corrado <62707311+jeremiah-corrado@users.noreply.github.com> Date: Fri, 27 Sep 2024 09:31:21 -0600 Subject: [PATCH 09/10] Update src/BinOp.chpl Co-authored-by: tess <48131946+stress-tess@users.noreply.github.com> --- src/BinOp.chpl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 4a65e22ecb..67ba1aa230 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -286,7 +286,6 @@ module BinOp [(ei,li,ri) in zip(ea,la,ra)] ei = floorDivisionHelper(li, ri); } otherwise { - writeln("wtf... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); return MsgTuple.error(nie); } } From 55bbd03386862421078cc82a08980fefcbec2387 Mon Sep 17 00:00:00 2001 From: jeremiah-corrado <62707311+jeremiah-corrado@users.noreply.github.com> Date: Fri, 27 Sep 2024 09:31:26 -0600 Subject: [PATCH 10/10] Update src/BinOp.chpl Co-authored-by: tess <48131946+stress-tess@users.noreply.github.com> --- src/BinOp.chpl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/BinOp.chpl b/src/BinOp.chpl index 67ba1aa230..71d84d2a99 100644 --- a/src/BinOp.chpl +++ b/src/BinOp.chpl @@ -289,7 +289,6 @@ module BinOp return MsgTuple.error(nie); } } - writeln("returning... ", op, " ", l.etype: string, " ", r.etype: string, " ", etype: string); return st.insert(new shared SymEntry(e)); } // If either RHS or LHS type is real, the same operations are supported and the