Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Align typeclass (#1263) #1755

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions core/src/main/scala/cats/Align.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package cats

import simulacrum.typeclass

import cats.data.Ior

/**
* `Align` supports zipping together structures with different shapes,
* holding the results from either or both structures in an `Ior`.
*
* Must obey the laws in cats.laws.AlignLaws
*/
@typeclass trait Align[F[_]] extends Functor[F] { self =>

/**
* An empty structure. `align`ing with `nil` will produce the same structure as the original, mod `Ior.Left` or `Ior.Right`.
*
* Align[Option].nil[Int] = None
*/
def nil[A]: F[A]

/**
* Pairs elements of two structures along the union of their shapes, using `Ior` to hold the results.
*
* Align[List].align(List(1, 2), List(10, 11, 12)) = List(Ior.Both(1, 10), Ior.Both(2, 11), Ior.Right(12))
*/
def align[A, B](fa: F[A], fb: F[B]): F[Ior[A, B]]

/**
* Combines elements similarly to `align`, using the provided function to compute the results.
*/
def alignWith[A, B, C](fa: F[A], fb: F[B])(f: Ior[A, B] => C): F[C] =
map(align(fa, fb))(f)

/**
* Align two structures with the same element, combining results according to their semigroup instances.
*/
def salign[A : Semigroup](fa1: F[A], fa2: F[A]): F[A] =
alignWith(fa1, fa2)(_.merge)

/**
* Same as `align`, but forgets from the type that one of the two elements must be present.
*/
def padZip[A, B](fa: F[A], fb: F[B]): F[(Option[A], Option[B])] =
alignWith(fa, fb)(_.pad)

/**
* Same as `alignWith`, but forgets from the type that one of the two elements must be present.
*/
def padZipWith[A, B, C](fa: F[A], fb: F[B])(f: (Option[A], Option[B]) => C): F[C] =
alignWith(fa, fb)(ior => Function.tupled(f)(ior.pad))
}
18 changes: 16 additions & 2 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package cats
package instances

import cats.data.Ior
import cats.syntax.show._

import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

trait ListInstances extends cats.kernel.instances.ListInstances {

implicit val catsStdInstancesForList: TraverseFilter[List] with MonadCombine[List] with Monad[List] with CoflatMap[List] =
new TraverseFilter[List] with MonadCombine[List] with Monad[List] with CoflatMap[List] {
implicit val catsStdInstancesForList: TraverseFilter[List] with MonadCombine[List] with Monad[List] with CoflatMap[List] with Align[List] =
new TraverseFilter[List] with MonadCombine[List] with Monad[List] with CoflatMap[List] with Align[List] {
def empty[A]: List[A] = Nil

def combineK[A](x: List[A], y: List[A]): List[A] = x ++ y
Expand Down Expand Up @@ -115,6 +116,19 @@ trait ListInstances extends cats.kernel.instances.ListInstances {
override def dropWhile_[A](fa: List[A])(p: A => Boolean): List[A] = fa.dropWhile(p)

override def algebra[A]: Monoid[List[A]] = new kernel.instances.ListMonoid[A]

def nil[A]: List[A] = Nil

def align[A, B](fa: List[A], fb: List[B]): List[A Ior B] = {
@tailrec def loop(buf: ListBuffer[Ior[A, B]], as: List[A], bs: List[B]): List[A Ior B] =
(as, bs) match {
case (a :: atail, b :: btail) => loop(buf += Ior.Both(a, b), atail, btail)
case (Nil, Nil) => buf.toList
case (arest, Nil) => (buf ++= arest.map(Ior.left)).toList
case (Nil, brest) => (buf ++= brest.map(Ior.right)).toList
}
loop(ListBuffer.empty[Ior[A, B]], fa, fb)
}
}

implicit def catsStdShowForList[A:Show]: Show[List[A]] =
Expand Down
22 changes: 20 additions & 2 deletions core/src/main/scala/cats/instances/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package instances

import scala.annotation.tailrec

import cats.data.Ior

trait MapInstances extends cats.kernel.instances.MapInstances {

implicit def catsStdShowForMap[A, B](implicit showA: Show[A], showB: Show[B]): Show[Map[A, B]] =
Expand All @@ -14,8 +16,8 @@ trait MapInstances extends cats.kernel.instances.MapInstances {
}

// scalastyle:off method.length
implicit def catsStdInstancesForMap[K]: TraverseFilter[Map[K, ?]] with FlatMap[Map[K, ?]] =
new TraverseFilter[Map[K, ?]] with FlatMap[Map[K, ?]] {
implicit def catsStdInstancesForMap[K]: TraverseFilter[Map[K, ?]] with FlatMap[Map[K, ?]] with Align[Map[K, ?]] =
new TraverseFilter[Map[K, ?]] with FlatMap[Map[K, ?]] with Align[Map[K, ?]] {

override def traverse[G[_], A, B](fa: Map[K, A])(f: A => G[B])(implicit G: Applicative[G]): G[Map[K, B]] = {
val gba: Eval[G[Map[K, B]]] = Always(G.pure(Map.empty))
Expand Down Expand Up @@ -89,6 +91,22 @@ trait MapInstances extends cats.kernel.instances.MapInstances {
A.combineAll(fa.values)

override def toList[A](fa: Map[K, A]): List[A] = fa.values.toList

override def nil[A]: Map[K, A] = Map.empty[K, A]

override def align[A, B](fa: Map[K, A], fb: Map[K, B]): Map[K, A Ior B] = {
val keys = fa.keySet ++ fb.keySet
val builder = new collection.mutable.MapBuilder[K, A Ior B, Map[K, A Ior B]](Map.empty[K, A Ior B])
builder.sizeHint(keys.size)
keys.foldLeft(builder) { (builder, k) =>
(fa.get(k), fb.get(k)) match {
case (Some(a), Some(b)) => builder += k -> Ior.both(a, b)
case (Some(a), None) => builder += k -> Ior.left(a)
case (None, Some(b)) => builder += k -> Ior.right(b)
case (None, None) => ??? // should not happen
}
}.result()
}
}
// scalastyle:on method.length
}
16 changes: 14 additions & 2 deletions core/src/main/scala/cats/instances/option.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package instances

import scala.annotation.tailrec

import cats.data.Ior

trait OptionInstances extends cats.kernel.instances.OptionInstances {

implicit val catsStdInstancesForOption: TraverseFilter[Option] with MonadError[Option, Unit] with MonadCombine[Option] with Monad[Option] with CoflatMap[Option] with Alternative[Option] =
new TraverseFilter[Option] with MonadError[Option, Unit] with MonadCombine[Option] with Monad[Option] with CoflatMap[Option] with Alternative[Option] {
implicit val catsStdInstancesForOption: TraverseFilter[Option] with MonadError[Option, Unit] with MonadCombine[Option] with Monad[Option] with CoflatMap[Option] with Alternative[Option] with Align[Option] =
new TraverseFilter[Option] with MonadError[Option, Unit] with MonadCombine[Option] with Monad[Option] with CoflatMap[Option] with Alternative[Option] with Align[Option] {

def empty[A]: Option[A] = None

Expand Down Expand Up @@ -116,6 +118,16 @@ trait OptionInstances extends cats.kernel.instances.OptionInstances {

override def isEmpty[A](fa: Option[A]): Boolean =
fa.isEmpty

override def nil[A]: Option[A] = None

override def align[A, B](fa: Option[A], fb: Option[B]): Option[A Ior B] =
(fa, fb) match {
case (None, None) => None
case (Some(a), None) => Some(Ior.left(a))
case (None, Some(b)) => Some(Ior.right(b))
case (Some(a), Some(b)) => Some(Ior.both(a, b))
}
}

implicit def catsStdShowForOption[A](implicit A: Show[A]): Show[Option[A]] =
Expand Down
16 changes: 14 additions & 2 deletions core/src/main/scala/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package cats
package instances

import cats.syntax.show._
import cats.data.Ior

import scala.annotation.tailrec

trait StreamInstances extends cats.kernel.instances.StreamInstances {
implicit val catsStdInstancesForStream: TraverseFilter[Stream] with MonadCombine[Stream] with CoflatMap[Stream] =
new TraverseFilter[Stream] with MonadCombine[Stream] with CoflatMap[Stream] {
implicit val catsStdInstancesForStream: TraverseFilter[Stream] with MonadCombine[Stream] with CoflatMap[Stream] with Align[Stream] =
new TraverseFilter[Stream] with MonadCombine[Stream] with CoflatMap[Stream] with Align[Stream] {

def empty[A]: Stream[A] = Stream.Empty

Expand Down Expand Up @@ -141,6 +142,17 @@ trait StreamInstances extends cats.kernel.instances.StreamInstances {
override def find[A](fa: Stream[A])(f: A => Boolean): Option[A] = fa.find(f)

override def algebra[A]: Monoid[Stream[A]] = new kernel.instances.StreamMonoid[A]

override def nil[A]: Stream[A] = Stream.Empty

override def align[A, B](fa: Stream[A], fb: Stream[B]): Stream[A Ior B] =
(fa, fb) match {
case ((a #:: atail), (b #:: btail)) => Ior.both(a, b) #:: align(atail, btail)
case (Stream.Empty, Stream.Empty) => Stream.Empty
case (arest, Stream.Empty) => arest.map(Ior.left)
case (Stream.Empty, brest) => brest.map(Ior.right)
}

}

implicit def catsStdShowForStream[A: Show]: Show[Stream[A]] =
Expand Down
16 changes: 14 additions & 2 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package cats
package instances

import cats.syntax.show._
import cats.data.Ior
import scala.annotation.tailrec
import scala.collection.+:
import scala.collection.immutable.VectorBuilder
import list._

trait VectorInstances extends cats.kernel.instances.VectorInstances {
implicit val catsStdInstancesForVector: TraverseFilter[Vector] with MonadCombine[Vector] with CoflatMap[Vector] =
new TraverseFilter[Vector] with MonadCombine[Vector] with CoflatMap[Vector] {
implicit val catsStdInstancesForVector: TraverseFilter[Vector] with MonadCombine[Vector] with CoflatMap[Vector] with Align[Vector] =
new TraverseFilter[Vector] with MonadCombine[Vector] with CoflatMap[Vector] with Align[Vector] {

def empty[A]: Vector[A] = Vector.empty[A]

Expand Down Expand Up @@ -104,6 +105,17 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
override def find[A](fa: Vector[A])(f: A => Boolean): Option[A] = fa.find(f)

override def algebra[A]: Monoid[Vector[A]] = new kernel.instances.VectorMonoid[A]

override def nil[A]: Vector[A] = Vector.empty[A]

override def align[A, B](fa: Vector[A], fb: Vector[B]): Vector[A Ior B] = {
val aLarger = fa.size >= fb.size
if (aLarger) {
(fa, fb).zipped.map(Ior.both) ++ fa.drop(fb.size).map(Ior.left)
} else {
(fa, fb).zipped.map(Ior.both) ++ fb.drop(fa.size).map(Ior.right)
}
}
}

implicit def catsStdShowForVector[A:Show]: Show[Vector[A]] =
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/scala/cats/syntax/align.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package cats
package syntax

trait AlignSyntax {
implicit final def catsSyntaxAlign[F[_], A](fa: F[A])(implicit F: Align[F]): Align.Ops[F, A] =
new Align.Ops[F, A] {
val self = fa
val typeClassInstance = F
}
}
3 changes: 2 additions & 1 deletion core/src/main/scala/cats/syntax/all.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package cats
package syntax

trait AllSyntax
extends ApplicativeSyntax
extends AlignSyntax
with ApplicativeSyntax
with ApplicativeErrorSyntax
with ApplySyntax
with BifunctorSyntax
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/cats/syntax/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats

package object syntax {
object all extends AllSyntax
object align extends AlignSyntax
object applicative extends ApplicativeSyntax
object applicativeError extends ApplicativeErrorSyntax
object apply extends ApplySyntax
Expand Down
34 changes: 34 additions & 0 deletions laws/src/main/scala/cats/laws/AlignLaws.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package cats
package laws

import cats.syntax.align._
import cats.syntax.functor._

import cats.data.Ior

/**
* Laws that must be obeyed by any `Align`.
*/
trait AlignLaws[F[_]] extends FunctorLaws[F] {
implicit override def F: Align[F]

def nilLeftIdentity[A, B](fb: F[B]): IsEq[F[A Ior B]] =
F.nil[A].align(fb) <-> fb.map(Ior.right)

def nilRightIdentity[A, B](fa: F[A]): IsEq[F[A Ior B]] =
fa.align(F.nil[B]) <-> fa.map(Ior.left)

def alignSelfBoth[A](fa: F[A]): IsEq[F[A Ior A]] =
fa.align(fa) <-> fa.map(a => Ior.both(a, a))

def alignHomomorphism[A, B, C, D](fa: F[A], fb: F[B], f: A => C, g: B => D): IsEq[F[C Ior D]] =
fa.map(f).align(fb.map(g)) <-> fa.align(fb).map(_.bimap(f, g))

def alignWithConsistent[A, B, C](fa: F[A], fb: F[B], f: A Ior B => C): IsEq[F[C]] =
fa.alignWith(fb)(f) <-> fa.align(fb).map(f)
}

object AlignLaws {
def apply[F[_]](implicit ev: Align[F]): AlignLaws[F] =
new AlignLaws[F] { def F: Align[F] = ev }
}
46 changes: 46 additions & 0 deletions laws/src/main/scala/cats/laws/discipline/AlignTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package cats
package laws
package discipline

import org.scalacheck.{Arbitrary, Cogen, Prop}
import Prop._

import cats.data.Ior

trait AlignTests[F[_]] extends FunctorTests[F] {
def laws: AlignLaws[F]

def align[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary](
implicit ArbFA: Arbitrary[F[A]],
ArbFB: Arbitrary[F[B]],
ArbFC: Arbitrary[F[C]],
ArbFAtoB: Arbitrary[A => C],
ArbFBtoC: Arbitrary[B => D],
ArbIorABtoC: Arbitrary[A Ior B => C],
CogenA: Cogen[A],
CogenB: Cogen[B],
CogenC: Cogen[C],
EqFA: Eq[F[A]],
EqFB: Eq[F[B]],
EqFC: Eq[F[C]],
EqFIorAA: Eq[F[A Ior A]],
EqFIorAB: Eq[F[A Ior B]],
EqFIorCD: Eq[F[C Ior D]]
): RuleSet = new DefaultRuleSet(
name = "align",
parent = Some(functor[A, B, C]),
"nil left identity" -> forAll(laws.nilLeftIdentity[A, B] _),
"nil right identity" -> forAll(laws.nilRightIdentity[A, B] _),
"align self both" -> forAll(laws.alignSelfBoth[A] _),
"align homomorphism" -> forAll { (fa: F[A], fb: F[B], f: A => C, g: B => D) =>
laws.alignHomomorphism[A, B, C, D](fa, fb, f, g)
},
"alignWith consistent" -> forAll { (fa: F[A], fb: F[B], f: A Ior B => C) =>
laws.alignWithConsistent[A, B, C](fa, fb, f)
})
}

object AlignTests {
def apply[F[_]: Align]: AlignTests[F] =
new AlignTests[F] { def laws: AlignLaws[F] = AlignLaws[F] }
}
5 changes: 4 additions & 1 deletion tests/src/test/scala/cats/tests/ListTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cats
package tests

import cats.data.NonEmptyList
import cats.laws.discipline.{TraverseFilterTests, CoflatMapTests, MonadCombineTests, SerializableTests, CartesianTests}
import cats.laws.discipline.{AlignTests, TraverseFilterTests, CoflatMapTests, MonadCombineTests, SerializableTests, CartesianTests}
import cats.laws.discipline.arbitrary._

class ListTests extends CatsSuite {
Expand All @@ -19,6 +19,9 @@ class ListTests extends CatsSuite {
checkAll("List[Int] with Option", TraverseFilterTests[List].traverseFilter[Int, Int, Int, List[Int], Option, Option])
checkAll("TraverseFilter[List]", SerializableTests.serializable(TraverseFilter[List]))

checkAll("List[Int]", AlignTests[List].align[Int, Int, Int, Int])
checkAll("Align[List]", SerializableTests.serializable(Align[List]))

test("nel => list => nel returns original nel")(
forAll { fa: NonEmptyList[Int] =>
fa.toList.toNel should === (Some(fa))
Expand Down
6 changes: 5 additions & 1 deletion tests/src/test/scala/cats/tests/MapTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package cats
package tests

import cats.laws.discipline.{TraverseFilterTests, FlatMapTests, SerializableTests, CartesianTests}
import cats.laws.discipline.{AlignTests, TraverseFilterTests, FlatMapTests, SerializableTests, CartesianTests}
import cats.laws.discipline.arbitrary._

class MapTests extends CatsSuite {
implicit val iso = CartesianTests.Isomorphisms.invariant[Map[Int, ?]]
Expand All @@ -15,6 +16,9 @@ class MapTests extends CatsSuite {
checkAll("Map[Int, Int] with Option", TraverseFilterTests[Map[Int, ?]].traverseFilter[Int, Int, Int, Int, Option, Option])
checkAll("TraverseFilter[Map[Int, ?]]", SerializableTests.serializable(TraverseFilter[Map[Int, ?]]))

checkAll("Map[Int, Int]", AlignTests[Map[Int, ?]].align[Int, Int, Int, Int])
checkAll("Align[Map]", SerializableTests.serializable(Align[Map[Int, ?]]))

test("show isn't empty and is formatted as expected") {
forAll { (map: Map[Int, String]) =>
map.show.nonEmpty should === (true)
Expand Down
Loading