From e3b44d2240d7ce98bf66b0c553bbe36202e509de Mon Sep 17 00:00:00 2001 From: Mathieu Ancelin Date: Wed, 10 Apr 2024 17:56:36 +0200 Subject: [PATCH] Avoid infinite retry loop when fetching bad binary --- .../io/otoroshi/wasm4s/impl/runtimev2.scala | 28 ++++++++---- .../wasm4s/scaladsl/integration.scala | 2 + .../io/otoroshi/wasm4s/scaladsl/wasm.scala | 44 ++++++++++++++++--- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/main/scala/io/otoroshi/wasm4s/impl/runtimev2.scala b/src/main/scala/io/otoroshi/wasm4s/impl/runtimev2.scala index 4d02e14..34c23a5 100644 --- a/src/main/scala/io/otoroshi/wasm4s/impl/runtimev2.scala +++ b/src/main/scala/io/otoroshi/wasm4s/impl/runtimev2.scala @@ -3,16 +3,14 @@ package io.otoroshi.wasm4s.impl import akka.stream.OverflowStrategy import akka.stream.scaladsl._ import com.codahale.metrics.UniformReservoir -import io.otoroshi.wasm4s.scaladsl.CacheableWasmScript.CachedWasmScript import io.otoroshi.wasm4s.scaladsl._ import io.otoroshi.wasm4s.scaladsl.opa._ import io.otoroshi.wasm4s.scaladsl.implicits._ -import io.otoroshi.wasm4s.impl.WasmVmPoolImpl.logger import org.extism.sdk.{HostFunction, HostUserData, Plugin} import org.extism.sdk.manifest.{Manifest, MemoryOptions} import org.extism.sdk.wasm.WasmSourceResolver import org.extism.sdk.wasmotoroshi._ -import play.api.Logger +import org.joda.time.DateTime import play.api.libs.json._ import java.util.concurrent.ConcurrentLinkedQueue @@ -256,7 +254,6 @@ case class WasmVmPoolAction(promise: Promise[WasmVmImpl], options: WasmVmInitOpt object WasmVmPoolImpl { - private[wasm4s] val logger = Logger("otoroshi-wasm-vm-pool") private val instances = new TrieMap[String, WasmVmPoolImpl]() def allInstances(): Map[String, WasmVmPoolImpl] = instances.synchronized { @@ -277,7 +274,7 @@ object WasmVmPoolImpl { class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration], maxCallsBetweenUpdates: Int = 100000, val ic: WasmIntegrationContext) extends WasmVmPool { - WasmVmPoolImpl.logger.debug("new WasmVmPool") + ic.logger.trace("new WasmVmPool") private val counter = new AtomicInteger(-1) private[wasm4s] val availableVms = new ConcurrentLinkedQueue[WasmVmImpl]() @@ -314,6 +311,12 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration .andThen { case _ => priorityQueue.offer(action) }(ic.executionContext) + } else if (wcfg.source.isFailed()(ic)) { + val until = wcfg.source.getFailedFromCache()(ic).get.until + if (until < time) { + wcfg.source.removeFromCache()(ic) + } + action.fail(new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}")) } else { // try to self refresh cache if more call than or time elapsed if (ic.selfRefreshingPools && (((time - lastCacheUpdateTime.get()) > ic.wasmCacheTtl) || (lastCacheUpdateCalls.get() > maxCallsBetweenUpdates))) { @@ -329,7 +332,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration // then we check if the underlying wasmcode + config has not changed since last time if (changed) { // if so, we destroy all current vms and recreate a new one - WasmVmPoolImpl.logger.warn("plugin has changed, destroying old instances") + ic.logger.warn("plugin has changed, destroying old instances") destroyCurrentVms() createVm(wcfg, action.options) } @@ -368,11 +371,15 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration private def createVm(config: WasmConfiguration, options: WasmVmInitOptions): Unit = synchronized { if (creatingRef.compareAndSet(false, true)) { val index = counter.incrementAndGet() - WasmVmPoolImpl.logger.debug(s"creating vm: ${index}") + ic.logger.debug(s"creating vm: ${index}") + if (config.source.isFailed()(ic)) { + creatingRef.compareAndSet(true, false) + return + } if (!config.source.isCached()(ic)) { // this part should never happen anymore, but just in case - WasmVmPoolImpl.logger.warn("fetching missing source") + ic.logger.warn("fetching missing source") Await.result(config.source.getWasm()(ic, ic.executionContext), 30.seconds) } lastPluginVersion.set(computeHash(config, config.source.cacheKey, ic.wasmScriptCache)) @@ -381,6 +388,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration val wasm = cache(key) match { case CacheableWasmScript.CachedWasmScript(script, _) => script case CacheableWasmScript.FetchingCachedWasmScript(_, script) => script + case CacheableWasmScript.FailedFetch(_, until) => throw new RuntimeException(s"accessing wasm binary was impossible. will retry after ${new DateTime(until).toString()}") case _ => throw new RuntimeException("unable to get wasm source from cache. this should not happen !") } // val hash = wasm.sha256 @@ -488,7 +496,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration // destroy all vms and clear everything in order to destroy the current pool private[wasm4s] def destroyCurrentVms(): Unit = availableVms.synchronized { - WasmVmPoolImpl.logger.info("destroying all vms") + ic.logger.info("destroying all vms") availableVms.asScala.foreach(_.destroy()) availableVms.clear() inUseVms.clear() @@ -509,6 +517,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration .map { case CacheableWasmScript.CachedWasmScript(wasm, _) => wasm.sha512 case CacheableWasmScript.FetchingCachedWasmScript(_, wasm) => wasm.sha512 + case CacheableWasmScript.FailedFetch(_, _) => "failed" case _ => "fetching" } .getOrElse("null") @@ -532,6 +541,7 @@ class WasmVmPoolImpl(stableId: => String, optConfig: => Option[WasmConfiguration val currentHash = computeHash(config, key, cache) oldHash != currentHash } + case Some(CacheableWasmScript.FailedFetch(_, _)) => false case _ => false } } diff --git a/src/main/scala/io/otoroshi/wasm4s/scaladsl/integration.scala b/src/main/scala/io/otoroshi/wasm4s/scaladsl/integration.scala index cbb6178..244c77a 100644 --- a/src/main/scala/io/otoroshi/wasm4s/scaladsl/integration.scala +++ b/src/main/scala/io/otoroshi/wasm4s/scaladsl/integration.scala @@ -26,6 +26,7 @@ trait WasmIntegrationContext { def materializer: Materializer def executionContext: ExecutionContext + def wasmFetchRetryAfterErrorDuration: FiniteDuration = 5.seconds def wasmCacheTtl: Long def wasmQueueBufferSize: Int def selfRefreshingPools: Boolean @@ -169,6 +170,7 @@ class WasmIntegration(ic: WasmIntegrationContext) { val sources = (pluginSources ++ inlineSources).distinct.map(s => (s.cacheKey, s)).toMap val now = System.currentTimeMillis() ic.wasmScriptCache.toSeq.foreach { + case (key, CacheableWasmScript.FailedFetch(_, until)) if now > until => ic.wasmScriptCache.remove(key) case (key, CacheableWasmScript.CachedWasmScript(_, createAt)) if (createAt + (ic.wasmCacheTtl * 2)) < now => { // 2 times should be enough sources.get(key) match { case Some(_) => () diff --git a/src/main/scala/io/otoroshi/wasm4s/scaladsl/wasm.scala b/src/main/scala/io/otoroshi/wasm4s/scaladsl/wasm.scala index b3bb78e..8b532fb 100644 --- a/src/main/scala/io/otoroshi/wasm4s/scaladsl/wasm.scala +++ b/src/main/scala/io/otoroshi/wasm4s/scaladsl/wasm.scala @@ -1,5 +1,6 @@ package io.otoroshi.wasm4s.scaladsl +import akka.http.scaladsl.util.FastFuture import akka.stream.scaladsl.StreamConverters import akka.util.ByteString import io.otoroshi.wasm4s.impl.OPAWasmVm @@ -7,6 +8,7 @@ import io.otoroshi.wasm4s.scaladsl.implicits._ import io.otoroshi.wasm4s.scaladsl.security._ import org.extism.sdk.{HostFunction, HostUserData, Plugin} import org.extism.sdk.wasmotoroshi._ +import org.joda.time.DateTime import play.api.libs.json._ import java.nio.file.{Files, Paths} @@ -321,11 +323,27 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o def json: JsValue = WasmSource.format.writes(this) def cacheKey = s"${kind.name.toLowerCase}://${path}" def getConfig()(implicit ic: WasmIntegrationContext, ec: ExecutionContext): Future[Option[WasmConfiguration]] = kind.getConfig(path, opts) + def removeFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript] = ic.wasmScriptCache.remove(cacheKey) + def getFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript] = ic.wasmScriptCache.get(cacheKey) + def getFailedFromCache()(implicit ic: WasmIntegrationContext): Option[CacheableWasmScript.FailedFetch] = { + getFromCache() match { + case Some(i @ CacheableWasmScript.FailedFetch(_, _)) => i.some + case _ => None + } + } + def isFailed()(implicit ic: WasmIntegrationContext): Boolean = { + val cache = ic.wasmScriptCache + cache.get(cacheKey) match { + case Some(CacheableWasmScript.FailedFetch(_, _)) => true + case _ => false + } + } def isCached()(implicit ic: WasmIntegrationContext): Boolean = { val cache = ic.wasmScriptCache cache.get(cacheKey) match { case Some(CacheableWasmScript.CachedWasmScript(_, _)) => true case Some(CacheableWasmScript.FetchingCachedWasmScript(_, _)) => true + case Some(CacheableWasmScript.FailedFetch(_, _)) => true case _ => false } } @@ -343,7 +361,9 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o case Left(err) => if (ic.logger.isErrorEnabled) ic.logger.error(s"[WasmSource] error while wasm fetch at ${path}: ${err}") maybeAlready match { - case None => cache.remove(cacheKey) + case None => + cache.remove(cacheKey) + cache.put(cacheKey, CacheableWasmScript.FailedFetch(System.currentTimeMillis(), System.currentTimeMillis() + ic.wasmFetchRetryAfterErrorDuration.toMillis)) case Some(s) => // put if back and wait for better times ??? if (ic.logger.isWarnEnabled) ic.logger.warn(s"[WasmSource] using old version of ${path} because of fetch error: ${err}") @@ -362,7 +382,9 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o case e => val err = Json.obj("error" -> s"error while getting wasm from source: ${e.getMessage}") maybeAlready match { - case None => cache.remove(cacheKey) + case None => + cache.remove(cacheKey) + cache.put(cacheKey, CacheableWasmScript.FailedFetch(System.currentTimeMillis(), System.currentTimeMillis() + ic.wasmFetchRetryAfterErrorDuration.toMillis)) case Some(s) => // put if back and wait ??? if (ic.logger.isWarnEnabled) ic.logger.warn(s"[WasmSource] using old version of ${path} because of recover error: ${err}") @@ -375,21 +397,28 @@ case class WasmSource(kind: WasmSourceKind, path: String, opts: JsValue = Json.o cache.get(cacheKey) match { case None => - if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm nothing in cache for ${path}") + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm nothing in cache for '${path}'") + fetchAndAddToCache(None) + case Some(CacheableWasmScript.FailedFetch(_, until)) if System.currentTimeMillis() > until => + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm has failed for '${path}'") fetchAndAddToCache(None) + Left(Json.obj("error" -> s"unable to access wasm binary. will retry after ${new DateTime(until).toString()}")).vfuture + case Some(CacheableWasmScript.FailedFetch(_, until)) => + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm has failed for '${path}'") + Left(Json.obj("error" -> s"unable to access wasm binary. will retry after ${new DateTime(until).toString()}")).vfuture case Some(CacheableWasmScript.FetchingWasmScript(fu)) => - if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching for ${path}") + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching for '${path}'") fu case Some(CacheableWasmScript.FetchingCachedWasmScript(_, script)) => - if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching already cached for ${path}") + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm fetching already cached for '${path}'") script.right.future case Some(CacheableWasmScript.CachedWasmScript(script, createAt)) if createAt + ic.wasmCacheTtl < System.currentTimeMillis => - if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm expired cache for ${path} - ${createAt} - ${ic.wasmCacheTtl} - ${createAt + ic.wasmCacheTtl} - ${System.currentTimeMillis}") + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm expired cache for '${path}' - ${createAt} - ${ic.wasmCacheTtl} - ${createAt + ic.wasmCacheTtl} - ${System.currentTimeMillis}") fetchAndAddToCache(script.some) script.right.vfuture case Some(CacheableWasmScript.CachedWasmScript(script, _)) => - if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm cached for ${path}") + if (ic.logger.isDebugEnabled) ic.logger.debug(s"[WasmSource] getWasm cached for '${path}'") script.right.vfuture } } @@ -532,6 +561,7 @@ object CacheableWasmScript { case class CachedWasmScript(script: ByteString, createAt: Long) extends CacheableWasmScript case class FetchingWasmScript(f: Future[Either[JsValue, ByteString]]) extends CacheableWasmScript case class FetchingCachedWasmScript(f: Future[Either[JsValue, ByteString]], script: ByteString) extends CacheableWasmScript + case class FailedFetch(createAt: Long, until: Long) extends CacheableWasmScript } case class WasmVmInitOptions(