Skip to content

Commit

Permalink
Merge pull request #70 from FineCinnamon/rr-free
Browse files Browse the repository at this point in the history
Free + stack safe monads
  • Loading branch information
ffgiraldez authored Apr 18, 2017
2 parents bea81d0 + b768a26 commit a087340
Show file tree
Hide file tree
Showing 14 changed files with 278 additions and 38 deletions.
6 changes: 6 additions & 0 deletions katz/src/main/kotlin/katz/arrow/FunctionK.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ interface FunctionK<in F, out G> {
*/
operator fun <A> invoke(fa: HK<F, A>): HK<G, A>

companion object {
fun <F> id(): FunctionK<F, F> = object : FunctionK<F, F> {
override fun <A> invoke(fa: HK<F, A>): HK<F, A> = fa
}
}

}

fun <F, G, H> FunctionK<F, G>.or(h: FunctionK<H, G>): FunctionK<CoproductFG<F, H>, G> =
Expand Down
8 changes: 5 additions & 3 deletions katz/src/main/kotlin/katz/data/NonEmptyList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ class NonEmptyList<out A> private constructor(
fun <B> flatMap(f: (A) -> NonEmptyList<B>): NonEmptyList<B> =
f(head) + tail.flatMap { f(it).all }

operator fun <A> NonEmptyList<A>.plus(l: NonEmptyList<A>): NonEmptyList<A> = NonEmptyList(all + l.all)
infix operator fun <A> NonEmptyList<A>.plus(l: NonEmptyList<A>): NonEmptyList<A> = NonEmptyList(all + l.all)

operator fun <A> NonEmptyList<A>.plus(l: List<A>): NonEmptyList<A> = NonEmptyList(all + l)
infix operator fun <A> NonEmptyList<A>.plus(l: List<A>): NonEmptyList<A> = NonEmptyList(all + l)

operator fun <A> NonEmptyList<A>.plus(a: A): NonEmptyList<A> = NonEmptyList(all + a)
infix operator fun <A> NonEmptyList<A>.plus(a: A): NonEmptyList<A> = NonEmptyList(all + a)

fun iterator(): Iterator<A> = all.iterator()

Expand All @@ -61,5 +61,7 @@ class NonEmptyList<out A> private constructor(

companion object : NonEmptyListMonad, GlobalInstance<Monad<NonEmptyList.F>>() {
fun <A> of(head: A, vararg t: A): NonEmptyList<A> = NonEmptyList(head, t.asList())
fun <A> fromList(l: List<A>): Option<NonEmptyList<A>> = if (l.isEmpty()) Option.None else Option.Some(NonEmptyList(l))
fun <A> fromListUnsafe(l: List<A>): NonEmptyList<A> = NonEmptyList(l)
}
}
58 changes: 58 additions & 0 deletions katz/src/main/kotlin/katz/free/Free.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package katz

typealias FreeKind<S, A> = HK2<Free.F, S, A>
typealias FreeF<S> = HK<Free.F, S>

fun <S, A> FreeKind<S, A>.ev(): Free<S, A> = this as Free<S, A>

sealed class Free<out S, out A> : FreeKind<S, A> {

class F private constructor()

companion object {
fun <S, A> pure(a: A): Free<S, A> = Pure(a)
fun <S, A> liftF(fa: HK<S, A>): Free<S, A> = Suspend(fa)
}

data class Pure<out S, out A>(val a: A) : Free<S, A>()
data class Suspend<out S, out A>(val a: HK<S, A>) : Free<S, A>()
data class FlatMapped<out S, out B, C>(val c: Free<S, C>, val f: (C) -> Free<S, B>) : Free<S, B>()

override fun toString(): String = "Free(...) : toString is not stack-safe"
}

fun <S, A, B> Free<S, A>.map(f: (A) -> B): Free<S, B> =
flatMap { Free.Pure<S, B>(f(it)) }

fun <S, A, B> Free<S, A>.flatMap(f: (A) -> Free<S, B>): Free<S, B> =
Free.FlatMapped(this, f)

@Suppress("UNCHECKED_CAST")
tailrec fun <S, A> Free<S, A>.step(): Free<S, A> =
if (this is Free.FlatMapped<S, A, *> && this.c is Free.FlatMapped<S, *, *>) {
val g = this.f as (A) -> Free<S, A>
val c = this.c.c as Free<S, A>
val f = this.c.f as (A) -> Free<S, A>
c.flatMap { cc -> f(cc).flatMap(g) }.step()
} else if (this is Free.FlatMapped<S, A, *> && this.c is Free.Pure<S, *>) {
val a = this.c.a as A
val f = this.f as (A) -> Free<S, A>
f(a).step()
} else {
this
}

@Suppress("UNCHECKED_CAST")
fun <M, S, A> Free<S, A>.foldMap(MM: Monad<M>, f: FunctionK<S, M>): HK<M, A> =
MM.tailRecM(this) {
val x = it.step()
when (x) {
is Free.Pure<S, A> -> MM.pure(Either.Right(x.a))
is Free.Suspend<S, A> -> MM.map(f(x.a), { Either.Right(it) })
is Free.FlatMapped<S, A, *> -> {
val g = (x.f as (A) -> Free<S, A>)
val c = x.c as Free<S, A>
MM.map(c.foldMap(MM, f), { cc -> Either.Left(g(cc)) })
}
}
}
13 changes: 12 additions & 1 deletion katz/src/main/kotlin/katz/instances/EitherMonad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ class EitherMonad<L> : Monad<EitherF<L>> {
override fun <A, B> flatMap(fa: EitherKind<L, A>, f: (A) -> EitherKind<L, B>): Either<L, B> {
return fa.ev().flatMap { f(it).ev() }
}

tailrec override fun <A, B> tailRecM(a: A, f: (A) -> HK<EitherF<L>, Either<A, B>>): Either<L, B> {
val e = f(a).ev().ev()
return when (e) {
is Either.Left -> e
is Either.Right -> when (e.b) {
is Either.Left -> tailRecM(e.b.a, f)
is Either.Right -> e.b
}
}
}
}

fun <A, B> EitherKind<A, B>.ev(): Either<A, B> = this as Either<A, B>
fun <A, B> EitherKind<A, B>.ev(): Either<A, B> = this as Either<A, B>
21 changes: 21 additions & 0 deletions katz/src/main/kotlin/katz/instances/FreeMonad.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package katz

interface FreeMonad<S> : Monad<FreeF<S>>, Typeclass {
override fun <A> pure(a: A): Free<S, A> =
Free.pure(a)

override fun <A, B> map(fa: FreeKind<S, A>, f: (A) -> B): HK<FreeF<S>, B> =
fa.ev().map(f)

override fun <A, B> flatMap(fa: FreeKind<S, A>, f: (A) -> FreeKind<S, B>): Free<S, B> =
fa.ev().flatMap { f(it).ev() }

override fun <A, B> tailRecM(a: A, f: (A) -> FreeKind<S, Either<A, B>>): Free<S, B> {
return f(a).ev().flatMap {
when (it) {
is Either.Left -> tailRecM(it.a, f)
is Either.Right -> pure(it.b)
}
}
}
}
9 changes: 9 additions & 0 deletions katz/src/main/kotlin/katz/instances/IdMonad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,13 @@ interface IdMonad : Monad<Id.F> {

override fun <A, B> flatMap(fa: IdKind<A>, f: (A) -> IdKind<B>): Id<B> =
fa.ev().flatMap { f(it).ev() }

@Suppress("UNCHECKED_CAST")
tailrec override fun <A, B> tailRecM(a: A, f: (A) -> IdKind<Either<A, B>>): Id<B> {
val x = f(a).ev().value
return when (x) {
is Either.Left<A> -> tailRecM(x.a, f)
is Either.Right<B> -> Id(x.b)
}
}
}
24 changes: 24 additions & 0 deletions katz/src/main/kotlin/katz/instances/IorMonad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@ class IorMonad<L>(val SL: Semigroup<L>) : Monad<HK<Ior.F, L>> {

override fun <A> pure(a: A): Ior<L, A> = Ior.Right(a)

private tailrec fun <A, B> loop(v: Ior<L, Either<A, B>>, f: (A) -> IorKind<L, Either<A, B>>): Ior<L, B> {
return when (v) {
is Ior.Left -> Ior.Left(v.value)
is Ior.Right -> when (v.value) {
is Either.Right -> Ior.Right(v.value.b)
is Either.Left -> loop(f(v.value.a).ev().ev(), f)
}
is Ior.Both -> when (v.rightValue) {
is Either.Right -> Ior.Both(v.leftValue, v.rightValue.b)
is Either.Left -> {
val fnb = f(v.rightValue.a).ev()
when (fnb) {
is Ior.Left -> Ior.Left(SL.combine(v.leftValue, fnb.value))
is Ior.Right -> loop(Ior.Both(v.leftValue, fnb.value), f)
is Ior.Both -> loop(Ior.Both(SL.combine(v.leftValue, fnb.leftValue), fnb.rightValue), f)
}
}
}
}
}

override fun <A, B> tailRecM(a: A, f: (A) -> IorKind<L, Either<A, B>>): Ior<L, B> {
return loop(f(a).ev(), f)
}
}

fun <A, B> IorKind<A, B>.ev(): Ior<A, B> = this as Ior<A, B>
19 changes: 19 additions & 0 deletions katz/src/main/kotlin/katz/instances/NonEmptyListMonad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,25 @@ interface NonEmptyListMonad : Monad<NonEmptyList.F> {
override fun <A, B> flatMap(fa: NonEmptyListKind<A>, f: (A) -> NonEmptyListKind<B>): NonEmptyList<B> =
fa.ev().flatMap { f(it).ev() }

@Suppress("UNCHECKED_CAST")
private tailrec fun <A, B> go(buf: ArrayList<B>, f: (A) -> HK<NonEmptyList.F, Either<A, B>>, v: NonEmptyList<Either<A, B>>): Unit =
when (v.head) {
is Either.Right<*> -> {
buf += v.head.b as B
val x = NonEmptyList.fromList(v.tail)
when (x) {
is Option.Some<NonEmptyList<Either<A, B>>> -> go(buf, f, x.value)
is Option.None -> Unit
}
}
is Either.Left<*> -> go(buf, f, NonEmptyList.fromListUnsafe(f(v.head.a as A).ev().all + v.tail))
}

override fun <A, B> tailRecM(a: A, f: (A) -> HK<NonEmptyList.F, Either<A, B>>): NonEmptyList<B> {
val buf = ArrayList<B>()
go(buf, f, f(a).ev())
return NonEmptyList.fromListUnsafe(buf)
}
}

fun <A> NonEmptyListKind<A>.ev(): NonEmptyList<A> = this as NonEmptyList<A>
14 changes: 13 additions & 1 deletion katz/src/main/kotlin/katz/instances/OptionMonad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,19 @@ interface OptionMonad : Monad<Option.F> {

override fun <A, B> flatMap(fa: OptionKind<A>, f: (A) -> OptionKind<B>): Option<B> =
fa.ev().flatMap { f(it).ev() }

tailrec override fun <A, B> tailRecM(a: A, f: (A) -> HK<Option.F, Either<A, B>>): Option<B> {
val option = f(a).ev()
return when (option) {
is Option.Some -> {
when (option.value) {
is Either.Left -> tailRecM(option.value.a, f)
is Either.Right -> Option.Some(option.value.b)
}
}
is Option.None -> Option.None
}
}
}

fun <A> OptionKind<A>.ev(): Option<A> = this as Option<A>

13 changes: 12 additions & 1 deletion katz/src/main/kotlin/katz/instances/OptionTMonad.kt
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
package katz

class OptionTMonad<F>(val MF : Monad<F>) : Monad<OptionTF<F>> {
class OptionTMonad<F>(val MF: Monad<F>) : Monad<OptionTF<F>> {
override fun <A> pure(a: A): OptionT<F, A> = OptionT(MF, MF.pure(Option(a)))

override fun <A, B> flatMap(fa: OptionTKind<F, A>, f: (A) -> OptionTKind<F, B>): OptionT<F, B> =
fa.ev().flatMap { f(it).ev() }

override fun <A, B> map(fa: OptionTKind<F, A>, f: (A) -> B): OptionT<F, B> =
fa.ev().map(f)

override fun <A, B> tailRecM(a: A, f: (A) -> HK<OptionTF<F>, Either<A, B>>): OptionT<F, B> =
OptionT(MF, MF.tailRecM(a, {
MF.map(f(it).ev().value, {
it.fold({
Either.Right<Option<B>>(Option.None)
}, {
it.map { Option.Some(it) }
})
})
}))
}

fun <F, A> OptionTKind<F, A>.ev(): OptionT<F, A> = this as OptionT<F, A>
8 changes: 8 additions & 0 deletions katz/src/main/kotlin/katz/instances/TryMonadError.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ interface TryMonadError : MonadError<Try.F, Throwable> {

override fun <A> handleErrorWith(fa: TryKind<A>, f: (Throwable) -> TryKind<A>): Try<A> =
fa.ev().recoverWith { f(it).ev() }

@Suppress("UNCHECKED_CAST")
override fun <A, B> tailRecM(a: A, f: (A) -> TryKind<Either<A, B>>): Try<B> {
val x = f(a).ev()
return if (x is Try.Success && x.value is Either.Left<A>) tailRecM(x.value.a, f)
else if (x is Try.Success && x.value is Either.Right<B>) Try.Success(x.value.b)
else x as Try.Failure<B>
}
}

fun <A> TryKind<A>.ev(): Try<A> = this as Try<A>
2 changes: 2 additions & 0 deletions katz/src/main/kotlin/katz/typeclasses/Monad.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ interface Monad<F> : Applicative<F>, Typeclass {

fun <A> flatten(ffa: HK<F, HK<F, A>>): HK<F, A> =
flatMap(ffa, { it })

fun <A, B> tailRecM(a: A, f: (A) -> HK<F, Either<A, B>>) : HK<F, B>
}

@RestrictsSuspension
Expand Down
87 changes: 87 additions & 0 deletions katz/src/test/kotlin/katz/free/FreeTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package katz

import io.kotlintest.KTestJUnitRunner
import io.kotlintest.matchers.shouldBe
import org.junit.runner.RunWith

sealed class Ops<A> : HK<Ops.F, A> {

class F private constructor()

data class Value(val a: Int) : Ops<Int>()
data class Add(val a: Int, val y: Int) : Ops<Int>()
data class Subtract(val a: Int, val y: Int) : Ops<Int>()

companion object : FreeMonad<Ops.F> {
fun value(n: Int): Free<Ops.F, Int> = Free.liftF(Ops.Value(n))
fun add(n: Int, y: Int): Free<Ops.F, Int> = Free.liftF(Ops.Add(n, y))
fun subtract(n: Int, y: Int): Free<Ops.F, Int> = Free.liftF(Ops.Subtract(n, y))
}
}

fun <A> HK<Ops.F, A>.ev(): Ops<A> = this as Ops<A>

val optionInterpreter: FunctionK<Ops.F, Option.F> = object : FunctionK<Ops.F, Option.F> {
override fun <A> invoke(fa: HK<Ops.F, A>): Option<A> {
val op = fa.ev()
return when (op) {
is Ops.Add -> Option.Some(op.a + op.y)
is Ops.Subtract -> Option.Some(op.a - op.y)
is Ops.Value -> Option.Some(op.a)
} as Option<A>
}
}

val nonEmptyListInterpter: FunctionK<Ops.F, NonEmptyList.F> = object : FunctionK<Ops.F, NonEmptyList.F> {
override fun <A> invoke(fa: HK<Ops.F, A>): NonEmptyList<A> {
val op = fa.ev()
return when (op) {
is Ops.Add -> NonEmptyList.of(op.a + op.y)
is Ops.Subtract -> NonEmptyList.of(op.a - op.y)
is Ops.Value -> NonEmptyList.of(op.a)
} as NonEmptyList<A>
}
}

val idInterpreter: FunctionK<Ops.F, Id.F> = object : FunctionK<Ops.F, Id.F> {
override fun <A> invoke(fa: HK<Ops.F, A>): Id<A> {
val op = fa.ev()
return when (op) {
is Ops.Add -> Id(op.a + op.y)
is Ops.Subtract -> Id(op.a - op.y)
is Ops.Value -> Id(op.a)
} as Id<A>
}
}

@RunWith(KTestJUnitRunner::class)
class FreeTest : UnitSpec() {

val program = Ops.binding {
val added = !Ops.add(10, 10)
val substracted = !Ops.subtract(added, 50)
yields(substracted)
}.ev()

fun stackSafeTestProgram(n: Int, stopAt: Int): Free<Ops.F, Int> = Ops.binding {
val v = !Ops.add(n, 1)
val r = !if (v < stopAt) stackSafeTestProgram(v, stopAt) else Free.pure<Ops.F, Int>(v)
yields(r)
}.ev()

init {

"Can interpret an ADT as Free operations" {
program.foldMap(Option, optionInterpreter).ev() shouldBe Option.Some(-30)
program.foldMap(Id, idInterpreter).ev() shouldBe Id(-30)
program.foldMap(NonEmptyList, nonEmptyListInterpter).ev() shouldBe NonEmptyList.of(-30)
}

"foldMap is stack safe" {
val n = 50000
val hugeProg = stackSafeTestProgram(0, n)
hugeProg.foldMap(Id, idInterpreter).value() shouldBe n
}

}
}
Loading

0 comments on commit a087340

Please sign in to comment.