Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make AbstractCurlBackend async friendly #2012

Merged
merged 4 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.monad.MonadError
import sttp.monad.syntax._

import scala.collection.immutable.Seq
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.scalanative.libc.stdio.{fclose, fopen, FILE}
import scala.scalanative.libc.stdlib._
Expand All @@ -25,33 +26,68 @@ import scala.scalanative.unsigned._
abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean) extends GenericBackend[F, Any] {
override implicit def monad: MonadError[F] = _monad

/** Given a [[CurlHandle]], perform the request and return a [[CurlCode]]. */
protected def performCurl(c: CurlHandle): F[CurlCode]

/** Same as [[performCurl]], but also checks and throws runtime exceptions on bad [[CurlCode]]s. */
private final def perform(c: CurlHandle) = performCurl(c).flatMap(lift)

type R = Any with Effect[F]

override def close(): F[Unit] = monad.unit(())

private var headers: CurlList = _
private var multiPartHeaders: Seq[CurlList] = Seq()
/** A request-specific context, with allocated zones and headers. */
private class Context() {
implicit val zone: Zone = Zone.open()
private val headers = ArrayBuffer[CurlList]()

/** Create a new Headers list that gets cleaned up when the context is destroyed. */
def transformHeaders(reqHeaders: Iterable[Header]): CurlList = {
val h = reqHeaders
.map(header => s"${header.name}: ${header.value}")
.foldLeft(new CurlList(null)) { case (acc, h) =>
new CurlList(acc.ptr.append(h))
}
headers += h
h
}

def close() = {
zone.close()
headers.foreach(l => if (l.ptr != null) l.ptr.free())
}
}

private object Context {

/** Create a new context and evaluates the body with it. Closes the context at the end. */
def evaluateUsing[T](body: Context => F[T]): F[T] = {
implicit val ctx = new Context()
body(ctx).ensure(monad.unit(ctx.close()))
}
}

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
adjustExceptions(request) {
unsafe.Zone { implicit z =>
def perform(implicit ctx: Context): F[Response[T]] = {
implicit val z = ctx.zone
val curl = CurlApi.init
if (verbose) {
curl.option(Verbose, parameter = true)
}
if (request.tags.nonEmpty) {
monad.error(new UnsupportedOperationException("Tags are not supported"))
return monad.error(new UnsupportedOperationException("Tags are not supported"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch! :)

}
val reqHeaders = request.headers
if (reqHeaders.nonEmpty) {
reqHeaders.find(_.name == "Accept-Encoding").foreach(h => curl.option(AcceptEncoding, h.value))
request.body match {
val headers = request.body match {
case _: MultipartBody[_] =>
headers = transformHeaders(
ctx.transformHeaders(
reqHeaders :+ Header.contentType(MediaType.MultipartFormData)
)
case _ =>
headers = transformHeaders(reqHeaders)
ctx.transformHeaders(reqHeaders)
}
curl.option(HttpHeader, headers.ptr)
}
Expand All @@ -62,6 +98,8 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
case None => handleBase(request, curl, spaces)
}
}

Context.evaluateUsing(ctx => perform(ctx))
}

private def adjustExceptions[T](request: GenericRequest[_, _])(t: => F[T]): F[T] =
Expand All @@ -70,22 +108,21 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
)

private def handleBase[T](request: GenericRequest[T, R], curl: CurlHandle, spaces: CurlSpaces)(implicit
z: unsafe.Zone
ctx: Context
) = {
implicit val z = ctx.zone
curl.option(WriteFunction, AbstractCurlBackend.wdFunc)
curl.option(WriteData, spaces.bodyResp)
curl.option(TimeoutMs, request.options.readTimeout.toMillis)
curl.option(HeaderData, spaces.headersResp)
curl.option(Url, request.uri.toString)
setMethod(curl, request.method)
setRequestBody(curl, request.body)
monad.flatMap(lift(curl.perform)) { _ =>
monad.flatMap(perform(curl)) { _ =>
curl.info(ResponseCode, spaces.httpCode)
val responseBody = fromCString((!spaces.bodyResp)._1)
val responseHeaders_ = parseHeaders(fromCString((!spaces.headersResp)._1))
val httpCode = StatusCode((!spaces.httpCode).toInt)
if (headers.ptr != null) headers.ptr.free()
multiPartHeaders.foreach(_.ptr.free())
free((!spaces.bodyResp)._1)
free((!spaces.headersResp)._1)
free(spaces.bodyResp.asInstanceOf[Ptr[CSignedChar]])
Expand All @@ -112,19 +149,18 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
}

private def handleFile[T](request: GenericRequest[T, R], curl: CurlHandle, file: SttpFile, spaces: CurlSpaces)(
implicit z: unsafe.Zone
implicit ctx: Context
) = {
implicit val z = ctx.zone
val outputPath = file.toPath.toString
val outputFilePtr: Ptr[FILE] = fopen(toCString(outputPath), toCString("wb"))
curl.option(WriteData, outputFilePtr)
curl.option(Url, request.uri.toString)
setMethod(curl, request.method)
setRequestBody(curl, request.body)
monad.flatMap(lift(curl.perform)) { _ =>
monad.flatMap(perform(curl)) { _ =>
curl.info(ResponseCode, spaces.httpCode)
val httpCode = StatusCode((!spaces.httpCode).toInt)
if (headers.ptr != null) headers.ptr.free()
multiPartHeaders.foreach(_.ptr.free())
free(spaces.httpCode.asInstanceOf[Ptr[CSignedChar]])
fclose(outputFilePtr)
curl.cleanup()
Expand Down Expand Up @@ -159,7 +195,10 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
lift(m)
}

private def setRequestBody(curl: CurlHandle, body: GenericRequestBody[R])(implicit zone: Zone): F[CurlCode] =
private def setRequestBody(curl: CurlHandle, body: GenericRequestBody[R])(implicit
ctx: Context
): F[CurlCode] = {
implicit val z = ctx.zone
body match { // todo: assign to monad object
case b: BasicBodyPart =>
val str = basicBodyToString(b)
Expand All @@ -176,9 +215,8 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean

val otherHeaders = headers.filterNot(_.is(HeaderNames.ContentType))
if (otherHeaders.nonEmpty) {
val curlList = transformHeaders(otherHeaders)
val curlList = ctx.transformHeaders(otherHeaders)
part.withHeaders(curlList.ptr)
multiPartHeaders = multiPartHeaders :+ curlList
}
}
lift(curl.option(Mimepost, mime))
Expand All @@ -187,6 +225,7 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
case NoBody =>
monad.unit(CurlCode.Ok)
}
}

private def basicBodyToString(body: BodyPart[_]): String =
body match {
Expand Down Expand Up @@ -253,13 +292,6 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
override protected def cleanupWhenGotWebSocket(response: Nothing, e: GotAWebSocketException): F[Unit] = response
}

private def transformHeaders(reqHeaders: Iterable[Header])(implicit z: Zone): CurlList =
reqHeaders
.map(header => s"${header.name}: ${header.value}")
.foldLeft(new CurlList(null)) { case (acc, h) =>
new CurlList(acc.ptr.append(h))
}

private def toByteArray(str: String): F[Array[Byte]] = monad.unit(str.getBytes)

private def lift(code: CurlCode): F[CurlCode] =
Expand All @@ -269,6 +301,12 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
}
}

/** Curl backends that performs the curl operation with a simple `curl_easy_perform`. */
abstract class AbstractSyncCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean)
extends AbstractCurlBackend[F](_monad, verbose) {
override def performCurl(c: CurlHandle): F[CurlCode.CurlCode] = monad.unit(c.perform)
}

object AbstractCurlBackend {
val wdFunc: CFuncPtr4[Ptr[Byte], CSize, CSize, Ptr[CurlFetch], CSize] = {
(ptr: Ptr[CChar], size: CSize, nmemb: CSize, data: Ptr[CurlFetch]) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import scala.util.Try

// Curl supports redirects, but it doesn't store the history, so using FollowRedirectsBackend is more convenient

private class CurlBackend(verbose: Boolean) extends AbstractCurlBackend(IdMonad, verbose) with SyncBackend {}
private class CurlBackend(verbose: Boolean) extends AbstractSyncCurlBackend(IdMonad, verbose) with SyncBackend {}

object CurlBackend {
def apply(verbose: Boolean = false): SyncBackend = FollowRedirectsBackend(new CurlBackend(verbose))
}

private class CurlTryBackend(verbose: Boolean) extends AbstractCurlBackend(TryMonad, verbose) with Backend[Try] {}
private class CurlTryBackend(verbose: Boolean) extends AbstractSyncCurlBackend(TryMonad, verbose) with Backend[Try] {}

object CurlTryBackend {
def apply(verbose: Boolean = false): Backend[Try] = FollowRedirectsBackend(new CurlTryBackend(verbose))
Expand Down
Loading