Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return stack-safe Function0 and Function1 from Semigroup#combine #4093

Merged
merged 4 commits into from
Dec 23, 2021

Conversation

mrdziuban
Copy link
Contributor

Fixes #4089.

This adds CombineFunction0 and CombineFunction1 classes as suggested by @johnynek in this comment to call and combine the results of two functions in a stack-safe way. Two things to note:

  • This won't magically make any function passed to Semigroup#combine stack-safe, it just ensures that the combined function itself won't be the cause of a StackOverflowError
  • I haven't added anything special to address the other idea of tracking call depth. If others feel this is important, I can look into doing so but may need some guidance

    Lastly, you could track the depth of CombineFn when you are building it so you do the expensive apply when the depth gets greater than something like 500 or something, and do the naive thing when depth is < than that (something like AndThen's hack to use a constant amount of stack to do applications).

@armanbilge armanbilge requested a review from johnynek December 21, 2021 19:19
Copy link
Contributor

@johnynek johnynek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this on!

for {
lb <- tailcall(call(l, a))
rb <- tailcall(call(r, a))
} yield sg.combine(lb, rb)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like scala 3 doesn't like this line. I think we could possibly fix it by doing:

private def call[C](fn: A => C, a: A): TailRec[C] =

and not tie the result type to B. Since Semigroup is invariant in B but A => B is covariant in B, actually the pattern match could give a B1 <: B. I think TailRec is covariant, so this should be fine... maybe scala 3 has a bug here, or I'm missing something...

Copy link
Contributor Author

@mrdziuban mrdziuban Dec 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That fixes it in scala 3 but breaks it in scala 2 😢

[error] FunctionInstances.scala:119:28: type mismatch;
[error]  found   : lc.type (with underlying type C)
[error]  required: ?C1 where type ?C1 <: C (this is a GADT skolem)
[error]         } yield sg.combine(lc, rc)
[error]                            ^

What I found does work is using a type ascription in the match along with an @unchecked annotation, similar to what AndThen does here.

I went ahead and committed that but would be happy to revert it if there's a better solution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to work for me:

final case class CombineFunction1[-A, B](left: A => B, right: A => B, semiB: Semigroup[B]) extends (A => B) {
    private[this] def call[C](fn: A => C, a: A): TailRec[C] =
      fn match {
        case ref: CombineFunction1[A, _] =>
          for {
            lb <- tailcall(call(ref.left, a))
            rb <- tailcall(call(ref.right, a))
          } yield ref.semiB.combine(lb, rb)
        case _ => done(fn(a))
      }

    override def apply(a: A): B = call(this, a).result
  }

@mrdziuban mrdziuban force-pushed the fn-semigroup-stack-safety branch from f2b247c to 8c47f0a Compare December 21, 2021 19:50
johnynek
johnynek previously approved these changes Dec 21, 2021
@joroKr21
Copy link
Member

joroKr21 commented Dec 22, 2021

What if we had simply:

case class CombineFunctions[A, B: Semigroup](fns: NonEmptyChain[A => B]) extends (A => B) {
  def apply(a: A) = fns.reduceMap(_.apply(a))
}

@johnynek
Copy link
Contributor

johnynek commented Dec 22, 2021

This with NonEmptyChain would work if that was accessible in kernel, although I'm not sure NonEmptyChain's reduceMap would be any faster than what we have here. I think we would want benchmarks.

(Edit: my original comment misread and thought you had NonEmptyList).

Also, passing the Semigroup as an implicit means it would not be used for equality which I think is a negative personally (since I've often leveraged AST equality in caching of interpreters).

@joroKr21
Copy link
Member

joroKr21 commented Dec 22, 2021

This with NonEmptyChain would work if that was accessible in kernel, although I'm not sure NonEmptyChain's reduceMap would be any faster than what we have here. I think we would want benchmarks.

Hmm we could use Vector then or LazyList/Stream. The main point is building a collection of functions on combine (by checking both sides if they are already CombineFunctions) and then the evaluation is reduce - no need to trampoline. We already know that it will be non-empty so we can unsafely reduce.

Also, passing the Semigroup as an implicit means it would not be used for equality which I think is a negative personally (since I've often leveraged AST equality in caching of interpreters).

Doesn't the current implementation suffer from the same problem? - i.e. if we have CombineFunciton on both sides we should also check that the Semigroup is the same.

@johnynek
Copy link
Contributor

repeat concatenation of Vector can result in O(N^2) performance. LazyList/Stream concatenations can result in stack overflow.

Doesn't the current implementation suffer from the same problem? - i.e. if we have CombineFunciton on both sides we should also check that the Semigroup is the same.

Note the current implementation only uses that semigroup for the outer most pair to combine. If an inner pair is also a CombineFunction then we use the semigroup associated with that. We should add a test for that, but if we don't have that we will definitely introduce some confusing bugs. So, I think the current implementation does correctly implement equals structurally.

Also, now that I think about it, I don't see how you can use the collection approach without storing the semigroups since in principle, they can all be different.

I'm not that attached to this problem. If you don't care about the O(N^2), I guess that's fine, since I don't think I've ever used this monoid, but consider the problem: if you are overflowing the stack, you have a long list, if you get N^2 performance there it is going to be slow.

If you really want it written another way, I'd recommend 1. making a benchmark, 2. making another PR. @mrdziuban has solved the issue, and spent the time to PR.

@joroKr21
Copy link
Member

I'm not very attached to this problem either - just had a different idea.

repeat concatenation of Vector can result in O(N^2) performance.

That doesn't sound right. Why do you think it would be O(N^2)?

LazyList/Stream concatenations can result in stack overflow.

Yeah, I guess if you append a lot you get the same problem.

Note the current implementation only uses that semigroup for the outer most pair to combine. If an inner pair is also a CombineFunction then we use the semigroup associated with that. We should add a test for that, but if we don't have that we will definitely introduce some confusing bugs. So, I think the current implementation does correctly implement equals structurally.

I'm not sure what you're saying but after looking a bit closer I managed to convince myself that the it would work correctly. On the other hand we don't use the structural equality of case classes here.

Also, now that I think about it, I don't see how you can use the collection approach without storing the semigroups since in principle, they can all be different.

Yes, you would store the semigroup either way.

If you really want it written another way, I'd recommend 1. making a benchmark, 2. making another PR. @mrdziuban has solved the issue, and spent the time to PR.

Yes, of course, kudos to @mrdziuban for tackling this tricky problem! I just wanted to poke a bit more at the problem to see if we can find a solution which doesn't have to pay the cost of trampolining every time we invoke the resulting function but it doesn't seem to be the case.

@joroKr21
Copy link
Member

Another approach I can think of would be to write a balance function that converts a chain of functions to a tree of functions (((((f |+| g) |+| h) |+| i) |+| j) => (((f |+| g) |+| (h |+| i)) |+| j) so that we need O(log2(N)) stack instead of O(N) stack but that may prove too complicated for this use case.

@johnynek
Copy link
Contributor

I'm not very attached to this problem either - just had a different idea.

repeat concatenation of Vector can result in O(N^2) performance.

That doesn't sound right. Why do you think it would be O(N^2)?

Why do you think it doesn't? 😀 I thought, e.g. prepending an item to a vector of length N did work order N, not O(1). So, if that's true, you have the same worst case as any repeated sequence operation (which motivated Chain in the first place). building a List in the wrong order (say appending on the right) of length N is O(N^2). Similarly building a string by repeatedly adding to head or tail. I think Vector has this same problem if you append to the head (the reverse order as List).

Yes, of course, kudos to @mrdziuban for tackling this tricky problem! I just wanted to poke a bit more at the problem to see if we can find a solution which doesn't have to pay the cost of trampolining every time we invoke the resulting function but it doesn't seem to be the case.

Your suggestion of using associativity to maintain a tree would be a solution IMO. I think log depth is fine (there is no way we are going to blow the stack in a runtime we can wait for, if log N = 5000).

I think that is trickier and can actually be implemented later without breaking binary compatibility since the trampoline is private in the current implementation.

What about merging this, and then someone can make a follow up with the tree approach?

@joroKr21
Copy link
Member

joroKr21 commented Dec 23, 2021

Vector has "effectively constant" append and prepend. Think of it as a tree with branching factor 32, so more precisely it's O(log32(N)) amortised complexity. But it's not great in terms of memory overhead for a small number of elements (less than 32) for that same reason. It also has good concatenation performance with other vectors.

What about merging this, and then someone can make a follow up with the tree approach?

Sounds good to me 👍

@johnynek
Copy link
Contributor

Vector has "effectively constant" append and prepend. Think of it as a tree with branching factor 32, so more precisely it's O(log32(N)) amortised complexity. But it's not great in terms of memory overhead for a small number of elements (less than 32) for that same reason. It also has good concatenation performance with other vectors.

I know how vector works, but I thought that at least the version prior to 2.13 didn't have efficient concat. Looking this code, I don't see it: https://github.com/scala/scala/blob/2.13.x/src/library/scala/collection/immutable/Vector.scala

I don't see a specialized path that offers O(log N) vector contact. It looks to me like you have to append in a loop the smaller of the two items. Am I missing it?

@johnynek johnynek merged commit 157033b into typelevel:main Dec 23, 2021
@joroKr21
Copy link
Member

joroKr21 commented Dec 24, 2021

It's here: https://github.com/scala/scala/blob/2.13.x/src/library/scala/collection/immutable/Vector.scala#L1546-L1558
The loop kicks in when one of the sides is significantly smaller.

@mrdziuban mrdziuban deleted the fn-semigroup-stack-safety branch January 3, 2022 14:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Semigroup[Function1[I, O]] produces a function that's not stack safe
4 participants