Skip to content

Commit

Permalink
fixed bug in batched pipeline (#288)
Browse files Browse the repository at this point in the history
* fixed bug in batched pipeline

* Fixed bug in batched pipeline to handle quotes within strings

* Added more test cases in PipelineSpec

* scala-redis/issues/286 added test case demontrating that ZADD stores java string values properly. (#291)

* Remove commented code

Co-authored-by: Noah Zucker <nzucker@gmail.com>
  • Loading branch information
debasishg and noahlz authored Nov 2, 2021
1 parent 68c6f3f commit 83dbcfc
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 40 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ lazy val commonSettings: Seq[Setting[_]] = Seq(
organization := "net.debasishg",
version := "3.42-SNAPSHOT",
scalaVersion := "2.13.6",
crossScalaVersions := Seq("2.12.14", "2.11.12", "2.10.7"),
crossScalaVersions := Seq("2.13.6", "2.12.14", "2.11.12", "2.10.7"),

scalacOptions in Compile ++= Seq( "-unchecked", "-feature", "-language:postfixOps", "-deprecation" ),

Expand Down
18 changes: 9 additions & 9 deletions src/main/scala/com/redis/BaseOperations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ trait BaseOperations extends BaseApi {
send("KEYS", List(pattern))(asList)

override def time[A](implicit format: Format, parse: Parse[A]): Option[List[Option[A]]] =
send("TIME", false)(asList)
send("TIME")(asList)

@deprecated("use randomkey", "2.8")
def randkey[A](implicit parse: Parse[A]): Option[A] =
send("RANDOMKEY", false)(asBulk)
send("RANDOMKEY")(asBulk)

override def randomkey[A](implicit parse: Parse[A]): Option[A] =
send("RANDOMKEY", false)(asBulk)
send("RANDOMKEY")(asBulk)

override def rename(oldkey: Any, newkey: Any)(implicit format: Format): Boolean =
send("RENAME", List(oldkey, newkey))(asBoolean)
Expand All @@ -63,7 +63,7 @@ trait BaseOperations extends BaseApi {
send("RENAMENX", List(oldkey, newkey))(asBoolean)

override def dbsize: Option[Long] =
send("DBSIZE", false)(asLong)
send("DBSIZE")(asLong)

override def exists(key: Any)(implicit format: Format): Boolean =
send("EXISTS", List(key))(asBoolean)
Expand Down Expand Up @@ -104,16 +104,16 @@ trait BaseOperations extends BaseApi {
})

override def flushdb: Boolean =
send("FLUSHDB", false)(asBoolean)
send("FLUSHDB")(asBoolean)

override def flushall: Boolean =
send("FLUSHALL", false)(asBoolean)
send("FLUSHALL")(asBoolean)

override def move(key: Any, db: Int)(implicit format: Format): Boolean =
send("MOVE", List(key, db))(asBoolean)

override def quit: Boolean =
send("QUIT", false)(disconnect)
send("QUIT")(disconnect)

override def auth(secret: Any)(implicit format: Format): Boolean =
send("AUTH", List(secret))(asBoolean)
Expand All @@ -125,13 +125,13 @@ trait BaseOperations extends BaseApi {
send("SCAN", cursor :: ((x: List[Any]) => if (pattern == "*") x else "match" :: pattern :: x) (if (count == 10) Nil else List("count", count)))(asPair)

override def ping: Option[String] =
send("PING", false)(asString)
send("PING")(asString)

override def watch(key: Any, keys: Any*)(implicit format: Format): Boolean =
send("WATCH", key :: keys.toList)(asBoolean)

override def unwatch(): Boolean =
send("UNWATCH", false)(asBoolean)
send("UNWATCH")(asBoolean)

override def getConfig(key: Any = "*")(implicit format: Format): Option[Map[String, Option[String]]] =
send("CONFIG", List("GET", key))(asList).map { ls =>
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/com/redis/NodeOperations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@ trait NodeOperations extends NodeApi {
self: Redis =>

override def save: Boolean =
send("SAVE", false)(asBoolean)
send("SAVE")(asBoolean)

override def bgsave: Boolean =
send("BGSAVE", false)(asBoolean)
send("BGSAVE")(asBoolean)

override def lastsave: Option[Long] =
send("LASTSAVE", false)(asLong)
send("LASTSAVE")(asLong)

override def shutdown: Boolean =
send("SHUTDOWN", false)(asBoolean)
send("SHUTDOWN")(asBoolean)

override def bgrewriteaof: Boolean =
send("BGREWRITEAOF", false)(asBoolean)
send("BGREWRITEAOF")(asBoolean)

override def info: Option[String] =
send("INFO", false)(asBulk)
send("INFO")(asBulk)

override def monitor: Boolean =
send("MONITOR", false)(asBoolean)
send("MONITOR")(asBoolean)

override def slaveof(options: Any): Boolean = options match {
case (h: String, p: Int) =>
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/com/redis/PubSub.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ trait PubSub extends PubOperations { self: Redis =>
}

def pUnsubscribe(): Unit = {
send("PUNSUBSCRIBE", false)(())
send("PUNSUBSCRIBE")(())
}

def pUnsubscribe(channel: String, channels: String*): Unit = {
Expand All @@ -98,7 +98,7 @@ trait PubSub extends PubOperations { self: Redis =>
}

def unsubscribe(): Unit = {
val r = send("UNSUBSCRIBE", false)(())
val r = send("UNSUBSCRIBE")(())
pubSub = false
r
}
Expand Down
52 changes: 31 additions & 21 deletions src/main/scala/com/redis/RedisClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ object RedisClient {
)
.getOrElse(0)
}

case class CommandToSend(command: String, args: Seq[Array[Byte]])
}

import RedisClient._
abstract class Redis(batch: Mode) extends IO with Protocol {
var handlers: Vector[(String, () => Any)] = Vector.empty
var commandBuffer: StringBuffer = new StringBuffer
val crlf = "\r\n"

val commandBuffer = collection.mutable.ListBuffer.empty[CommandToSend]

def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try {
if (batch == BATCH) {
handlers :+= ((command, () => result))
commandBuffer.append((List(command) ++ args.toList).mkString(" ") ++ crlf)
commandBuffer += CommandToSend(command, args.map(format.apply))
null.asInstanceOf[A] // hack
} else {
write(Commands.multiBulk(command.getBytes("UTF-8") +: (args map (format.apply))))
Expand All @@ -53,26 +53,36 @@ abstract class Redis(batch: Mode) extends IO with Protocol {
else throw e
}

def send[A](command: String, submissionMode: Boolean = false)(result: => A): A = try {
def send[A](command: String)(result: => A): A = try {
if (batch == BATCH) {
if (!submissionMode) {
handlers :+= ((command, () => result))
commandBuffer.append(command ++ crlf)
null.asInstanceOf[A]
} else {
write(command.getBytes("UTF-8"))
result
}
handlers :+= ((command, () => result))
commandBuffer += CommandToSend(command, Seq.empty[Array[Byte]])
null.asInstanceOf[A]
} else {
write(Commands.multiBulk(List(command.getBytes("UTF-8"))))
result
}
} catch {
case e: RedisConnectionException =>
if (disconnect) send(command, submissionMode)(result)
if (disconnect) send(command)(result)
else throw e
case e: SocketException =>
if (disconnect) send(command)(result)
else throw e
}

def send[A](commands: List[CommandToSend])(result: => A): A = try {
val cs = commands.map { command =>
command.command.getBytes("UTF-8") +: command.args
}
write(Commands.multiMultiBulk(cs))
result
} catch {
case e: RedisConnectionException =>
if (disconnect) send(commands)(result)
else throw e
case e: SocketException =>
if (disconnect) send(command, submissionMode)(result)
if (disconnect) send(commands)(result)
else throw e
}

Expand Down Expand Up @@ -143,17 +153,17 @@ class RedisClient(override val host: String, override val port: Int,
* @see https://redis.io/commands/multi
*/
def pipeline(f: PipelineClient => Any): Option[List[Any]] = {
send("MULTI", false)(asString) // flush reply stream
send("MULTI")(asString) // flush reply stream
try {
val pipelineClient = new PipelineClient(this)
try {
f(pipelineClient)
} catch {
case e: Exception =>
send("DISCARD", false)(asString)
send("DISCARD")(asString)
throw e
}
send("EXEC", false)(asExec(pipelineClient.responseHandlers))
send("EXEC")(asExec(pipelineClient.responseHandlers))
} catch {
case e: RedisMultiExecException =>
None
Expand Down Expand Up @@ -226,9 +236,9 @@ class RedisClient(override val host: String, override val port: Int,
commands.foreach { command =>
command()
}
val r = send(commandBuffer.toString, true)(Some(handlers.map(_._2).map(_()).toList))
val r = send(commandBuffer.toList)(Some(handlers.map(_._2).map(_()).toList))
handlers = Vector.empty
commandBuffer.setLength(0)
commandBuffer.clear()
r
}

Expand All @@ -248,7 +258,7 @@ class RedisClient(override val host: String, override val port: Int,
receive(singleLineReply).map(Parse.parseDefault)
null.asInstanceOf[A] // ugh... gotta find a better way
}
override def send[A](command: String, submissionMode: Boolean = false)(result: => A): A = {
override def send[A](command: String)(result: => A): A = {
write(Commands.multiBulk(List(command.getBytes("UTF-8"))))
responseHandlers :+= (() => result)
receive(singleLineReply).map(Parse.parseDefault)
Expand Down
12 changes: 12 additions & 0 deletions src/main/scala/com/redis/RedisProtocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ private [redis] object Commands {
}
b.result()
}

def multiMultiBulk(commands: Seq[Seq[Array[Byte]]]): Array[Byte] = {
val no = commands.size
val b = new scala.collection.mutable.ArrayBuilder.ofByte
// b ++= "*%d".format(no).getBytes
// b ++= LS
commands.foreach { command =>
b ++= multiBulk(command)
b ++= LS
}
b.result()
}
}

import Commands._
Expand Down
Loading

0 comments on commit 83dbcfc

Please sign in to comment.