Skip to content

Commit

Permalink
Avoid infinite retry loop when fetching bad binary
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieuancelin committed Apr 10, 2024
1 parent cc4e51f commit e3b44d2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
28 changes: 19 additions & 9 deletions src/main/scala/io/otoroshi/wasm4s/impl/runtimev2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@ package io.otoroshi.wasm4s.impl
import akka.stream.OverflowStrategy
import akka.stream.scaladsl._
import com.codahale.metrics.UniformReservoir
import io.otoroshi.wasm4s.scaladsl.CacheableWasmScript.CachedWasmScript
import io.otoroshi.wasm4s.scaladsl._
import io.otoroshi.wasm4s.scaladsl.opa._
import io.otoroshi.wasm4s.scaladsl.implicits._
import io.otoroshi.wasm4s.impl.WasmVmPoolImpl.logger
import org.extism.sdk.{HostFunction, HostUserData, Plugin}
import org.extism.sdk.manifest.{Manifest, MemoryOptions}
import org.extism.sdk.wasm.WasmSourceResolver
import org.extism.sdk.wasmotoroshi._
import play.api.Logger
import org.joda.time.DateTime
import play.api.libs.json._

import java.util.concurrent.ConcurrentLinkedQueue
Expand Down Expand Up @@ -256,7 +254,6 @@ case class WasmVmPoolAction(promise: Promise[WasmVmImpl], options: WasmVmInitOpt

object WasmVmPoolImpl {

private[wasm4s] val logger = Logger("otoroshi-wasm-vm-pool")
private val instances = new TrieMap[String, WasmVmPoolImpl]()

def allInstances(): Map[String, WasmVmPoolImpl] = instances.synchronized {
Expand All @@ -277,7 +274,7 @@ object WasmVmPoolImpl {

class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration], maxCallsBetweenUpdates: Int = 100000, val ic: WasmIntegrationContext) extends WasmVmPool {

WasmVmPoolImpl.logger.debug("new WasmVmPool")
ic.logger.trace("new WasmVmPool")

private val counter = new AtomicInteger(-1)
private[wasm4s] val availableVms = new ConcurrentLinkedQueue[WasmVmImpl]()
Expand Down Expand Up @@ -314,6 +311,12 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
.andThen { case _ =>
priorityQueue.offer(action)
}(ic.executionContext)
} else if (wcfg.source.isFailed()(ic)) {
val until = wcfg.source.getFailedFromCache()(ic).get.until
if (until < time) {
wcfg.source.removeFromCache()(ic)
}
action.fail(new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}"))
} else {
// try to self refresh cache if more call than or time elapsed
if (ic.selfRefreshingPools && (((time - lastCacheUpdateTime.get()) > ic.wasmCacheTtl) || (lastCacheUpdateCalls.get() > maxCallsBetweenUpdates))) {
Expand All @@ -329,7 +332,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
// then we check if the underlying wasmcode + config has not changed since last time
if (changed) {
// if so, we destroy all current vms and recreate a new one
WasmVmPoolImpl.logger.warn("plugin has changed, destroying old instances")
ic.logger.warn("plugin has changed, destroying old instances")
destroyCurrentVms()
createVm(wcfg, action.options)
}
Expand Down Expand Up @@ -368,11 +371,15 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
private def createVm(config: WasmConfiguration, options: WasmVmInitOptions): Unit = synchronized {
if (creatingRef.compareAndSet(false, true)) {
val index = counter.incrementAndGet()
WasmVmPoolImpl.logger.debug(s"creating vm: ${index}")
ic.logger.debug(s"creating vm: ${index}")

if (config.source.isFailed()(ic)) {
creatingRef.compareAndSet(true, false)
return
}
if (!config.source.isCached()(ic)) {
// this part should never happen anymore, but just in case
WasmVmPoolImpl.logger.warn("fetching missing source")
ic.logger.warn("fetching missing source")
Await.result(config.source.getWasm()(ic, ic.executionContext), 30.seconds)
}
lastPluginVersion.set(computeHash(config, config.source.cacheKey, ic.wasmScriptCache))
Expand All @@ -381,6 +388,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
val wasm = cache(key) match {
case CacheableWasmScript.CachedWasmScript(script, _) => script
case CacheableWasmScript.FetchingCachedWasmScript(_, script) => script
case CacheableWasmScript.FailedFetch(_, until) => throw new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}")
case _ => throw new RuntimeException("unable to get wasm source from cache. this should not happen !")
}
// val hash = wasm.sha256
Expand Down Expand Up @@ -488,7 +496,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration

// destroy all vms and clear everything in order to destroy the current pool
private[wasm4s] def destroyCurrentVms(): Unit = availableVms.synchronized {
WasmVmPoolImpl.logger.info("destroying all vms")
ic.logger.info("destroying all vms")
availableVms.asScala.foreach(_.destroy())
availableVms.clear()
inUseVms.clear()
Expand All @@ -509,6 +517,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
.map {
case CacheableWasmScript.CachedWasmScript(wasm, _) => wasm.sha512
case CacheableWasmScript.FetchingCachedWasmScript(_, wasm) => wasm.sha512
case CacheableWasmScript.FailedFetch(_, _) => "failed"
case _ => "fetching"
}
.getOrElse("null")
Expand All @@ -532,6 +541,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration
val currentHash = computeHash(config, key, cache)
oldHash != currentHash
}
case Some(CacheableWasmScript.FailedFetch(_, _)) => false
case _ => false
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/io/otoroshi/wasm4s/scaladsl/integration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ trait WasmIntegrationContext {
def materializer: Materializer
def executionContext: ExecutionContext

def wasmFetchRetryAfterErrorDuration: FiniteDuration = 5.seconds
def wasmCacheTtl: Long
def wasmQueueBufferSize: Int
def selfRefreshingPools: Boolean
Expand Down Expand Up @@ -169,6 +170,7 @@ class WasmIntegration(ic: WasmIntegrationContext) {
val sources = (pluginSources ++ inlineSources).distinct.map(s => (s.cacheKey, s)).toMap
val now = System.currentTimeMillis()
ic.wasmScriptCache.toSeq.foreach {
case (key, CacheableWasmScript.FailedFetch(_, until)) if now > until => ic.wasmScriptCache.remove(key)
case (key, CacheableWasmScript.CachedWasmScript(_, createAt)) if (createAt + (ic.wasmCacheTtl * 2)) < now => { // 2 times should be enough
sources.get(key) match {
case Some(_) => ()
Expand Down
44 changes: 37 additions & 7 deletions src/main/scala/io/otoroshi/wasm4s/scaladsl/wasm.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.otoroshi.wasm4s.scaladsl

import akka.http.scaladsl.util.FastFuture
import akka.stream.scaladsl.StreamConverters
import akka.util.ByteString
import io.otoroshi.wasm4s.impl.OPAWasmVm
import io.otoroshi.wasm4s.scaladsl.implicits._
import io.otoroshi.wasm4s.scaladsl.security._
import org.extism.sdk.{HostFunction, HostUserData, Plugin}
import org.extism.sdk.wasmotoroshi._
import org.joda.time.DateTime
import play.api.libs.json._

import java.nio.file.{Files, Paths}
Expand Down Expand Up @@ -321,11 +323,27 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o
def json: JsValue = WasmSource.format.writes(this)
def cacheKey = s"${kind.name.toLowerCase}://${path}"
def getConfig()(implicit ic: WasmIntegrationContext, ec: ExecutionContext): Future[Option[WasmConfiguration]] = kind.getConfig(path, opts)
def removeFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript] = ic.wasmScriptCache.remove(cacheKey)
def getFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript] = ic.wasmScriptCache.get(cacheKey)
def getFailedFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript.FailedFetch] = {
getFromCache() match {
case Some(i @ CacheableWasmScript.FailedFetch(_, _)) => i.some
case _ => None
}
}
def isFailed()(implicit ic: WasmIntegrationContext): Boolean = {
val cache = ic.wasmScriptCache
cache.get(cacheKey) match {
case Some(CacheableWasmScript.FailedFetch(_, _)) => true
case _ => false
}
}
def isCached()(implicit ic: WasmIntegrationContext): Boolean = {
val cache = ic.wasmScriptCache
cache.get(cacheKey) match {
case Some(CacheableWasmScript.CachedWasmScript(_, _)) => true
case Some(CacheableWasmScript.FetchingCachedWasmScript(_, _)) => true
case Some(CacheableWasmScript.FailedFetch(_, _)) => true
case _ => false
}
}
Expand All @@ -343,7 +361,9 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o
case Left(err) =>
if (ic.logger.isErrorEnabled) ic.logger.error(s"[WasmSource] error while wasm fetch at ${path}: ${err}")
maybeAlready match {
case None => cache.remove(cacheKey)
case None =>
cache.remove(cacheKey)
cache.put(cacheKey, CacheableWasmScript.FailedFetch(System.currentTimeMillis(), System.currentTimeMillis() + ic.wasmFetchRetryAfterErrorDuration.toMillis))
case Some(s) =>
// put if back and wait for better times ???
if (ic.logger.isWarnEnabled) ic.logger.warn(s"[WasmSource] using old version of ${path} because of fetch error: ${err}")
Expand All @@ -362,7 +382,9 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o
case e =>
val err = Json.obj("error" -> s"error while getting wasm from source: ${e.getMessage}")
maybeAlready match {
case None => cache.remove(cacheKey)
case None =>
cache.remove(cacheKey)
cache.put(cacheKey, CacheableWasmScript.FailedFetch(System.currentTimeMillis(), System.currentTimeMillis() + ic.wasmFetchRetryAfterErrorDuration.toMillis))
case Some(s) =>
// put if back and wait ???
if (ic.logger.isWarnEnabled) ic.logger.warn(s"[WasmSource] using old version of ${path} because of recover error: ${err}")
Expand All @@ -375,21 +397,28 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o

cache.get(cacheKey) match {
case None =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm nothing in cache for ${path}")
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm nothing in cache for '${path}'")
fetchAndAddToCache(None)
case Some(CacheableWasmScript.FailedFetch(_, until)) if System.currentTimeMillis() > until =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm has failed for '${path}'")
fetchAndAddToCache(None)
Left(Json.obj("error" -> s"unable to access wasm binary. will retry after ${new DateTime(until).toString()}")).vfuture
case Some(CacheableWasmScript.FailedFetch(_, until)) =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm has failed for '${path}'")
Left(Json.obj("error" -> s"unable to access wasm binary. will retry after ${new DateTime(until).toString()}")).vfuture
case Some(CacheableWasmScript.FetchingWasmScript(fu)) =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching for ${path}")
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching for '${path}'")
fu
case Some(CacheableWasmScript.FetchingCachedWasmScript(_, script)) =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching already cached for ${path}")
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching already cached for '${path}'")
script.right.future
case Some(CacheableWasmScript.CachedWasmScript(script, createAt))
if createAt + ic.wasmCacheTtl < System.currentTimeMillis =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm expired cache for ${path} - ${createAt} - ${ic.wasmCacheTtl} - ${createAt + ic.wasmCacheTtl} - ${System.currentTimeMillis}")
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm expired cache for '${path}' - ${createAt} - ${ic.wasmCacheTtl} - ${createAt + ic.wasmCacheTtl} - ${System.currentTimeMillis}")
fetchAndAddToCache(script.some)
script.right.vfuture
case Some(CacheableWasmScript.CachedWasmScript(script, _)) =>
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm cached for ${path}")
if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm cached for '${path}'")
script.right.vfuture
}
}
Expand Down Expand Up @@ -532,6 +561,7 @@ object CacheableWasmScript {
case class CachedWasmScript(script: ByteString, createAt: Long) extends CacheableWasmScript
case class FetchingWasmScript(f: Future[Either[JsValue, ByteString]]) extends CacheableWasmScript
case class FetchingCachedWasmScript(f: Future[Either[JsValue, ByteString]], script: ByteString) extends CacheableWasmScript
case class FailedFetch(createAt: Long, until: Long) extends CacheableWasmScript
}

case class WasmVmInitOptions(
Expand Down

0 comments on commit e3b44d2

Please sign in to comment.