Skip to content

Commit

Permalink
Implemented pure transforming of context function
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Sep 7, 2024
1 parent a565b9b commit e3ba29e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,13 @@ trait ApplyArgRecordScope[F[_], CT, CC<:CpsMonadContext[F]]:
val paramTypes = params.map(_.tpt.tpe)
cpsBody.syncOrigin match
case Some(syncBody) =>
if (cpsBody.isChanged) then
if (term.tpe.isContextFunctionType && !allowUncontext) then
println(s"context cpsBody = ${cpsBody}, ")
throw MacroError("Can't transform context function: TastyAPI don;t support this yet",posExpr(term))
val mt = MethodType(paramNames)(_ => paramTypes, _ => syncBody.tpe.widen)
if (cpsBody.isChanged) then
val methodKind = if (term.tpe.isContextFunctionType && !allowUncontext) MethodTypeKind.Contextual else MethodTypeKind.Plain
//if (term.tpe.isContextFunctionType && !allowUncontext) then
// val mt = MethodType(MethodTypeKind.Contextual)(paramNames)(_ => paramTypes, _ => syncBody.tpe.widen)
// println(s"context cpsBody = ${cpsBody}, ")
// throw MacroError("Can't transform context function: TastyAPI don;t support this yet",posExpr(term))
val mt = MethodType(methodKind)(paramNames)(_ => paramTypes, _ => syncBody.tpe.widen)
Lambda(owner, mt,
(owner,args) => changeArgs(params,args,syncBody,owner).changeOwner(owner))
else
Expand Down
2 changes: 2 additions & 0 deletions shared/src/test/scala/cpstest/TestFunBlock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.junit.{Test,Ignore}

class TestFunBlock {



@Test
def testApplyWithFunBlock(): Unit = {
var x = 0
Expand Down
65 changes: 65 additions & 0 deletions shared/src/test/scala/cpstest/TestNonShiftedContextFunction.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package cpstest

import scala.util.*

import cps.{*,given}
import cps.monads.{*, given}

import org.junit.{Test,Ignore}

class TestNonShiftedContextFunction {

case class Context(x:String)



object O {

def apply[T](c:Context ?=> T): T = c(using Context("O"))

def p0(x: => Int): Int = x + 1
def p0_async(x:() => ComputationBound[Int]): Int =
x().run().get + 1

def p1(x: Int=>Int): Int = x(1) + 1
def p1_async(x: Int=>ComputationBound[Int]): Int =
x(1).run().get + 1

}

@Test
def testApplyContextFunctionP1(): Unit = {
var x = 0
val c = async[ComputationBound] {
val a = await(T1.cbi(2))
O {
//val _ = ((x:Int) => x + await(T1.cbi(3)))
val q = O.p1( x => x+ await(T1.cbi(3)) )
if (false) then
println(s"ctx=${summon[Context]}")
a + q
}
}
//println(s"run=${c.run()}")
assert(c.run() == Success(7))
}


@Test
def testApplyContextFunctionP0(): Unit = {
var x = 0
val c = async[ComputationBound] {
val a = await(T1.cbi(2))
O {
//val _ = ((x:Int) => x + await(T1.cbi(3)))
val q = O.p0(await(T1.cbi(3)))
if (false) then
println(s"ctx=${summon[Context]}")
a + q
}
}
assert(c.run() == Success(6))
}


}

0 comments on commit e3ba29e

Please sign in to comment.