
Partial evaluation of scan computation evaluates loop counter at compile time #3108

Closed
hawkinsp opened this issue May 15, 2020 · 0 comments · Fixed by #4038
Labels: bug (Something isn't working)


@hawkinsp
Collaborator

Repro, derived from #3076

```python
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np


@jax.jit
def _polyval(p, x):
  shape = lax.broadcast_shapes(p.shape[1:], x.shape)
  dtype = jnp.result_type(p, x)
  y = lax.full_like(x, 0, shape=shape, dtype=dtype)
  y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p)
  return y


def polyval(p, x):
  return _polyval(p, x)


x = np.random.rand()
p = np.random.randn(10000)
pv = jax.jit(polyval)
pv(p, x).block_until_ready()
pv(p, x).block_until_ready()
```
pv(p, x).block_until_ready()

This code appears to evaluate three loops: one at compile time, and one for each of the two `pv(p, x)` calls. The compile-time loop appears pointless and wasteful. After optimization, its HLO boils down to:

```
HloModule primitive_computation_while.37

%body_computation.6.clone (parameter.5: (s32[], pred[])) -> (s32[], pred[]) {
  %parameter.5 = (s32[], pred[]) parameter(0)
  %get-tuple-element.14 = s32[] get-tuple-element((s32[], pred[]) %parameter.5), index=0
  %constant_1 = s32[] constant(1), metadata={op_type="add" op_name="while/body/add"}
  %add = s32[] add(s32[] %get-tuple-element.14, s32[] %constant_1), metadata={op_type="add" op_name="while/body/add"}
  %constant_2 = pred[] constant(false)
  %copy.8 = pred[] copy(pred[] %constant_2)
  ROOT %tuple.6 = (s32[], pred[]) tuple(s32[] %add, pred[] %copy.8)
}

%cond_computation.22.clone (parameter.0: (s32[], pred[])) -> pred[] {
  %parameter.0 = (s32[], pred[]) parameter(0)
  %get-tuple-element.2 = s32[] get-tuple-element((s32[], pred[]) %parameter.0), index=0
  %constant = s32[] constant(10000), metadata={op_type="lt" op_name="while/cond/lt"}
  ROOT %compare = pred[] compare(s32[] %get-tuple-element.2, s32[] %constant), direction=LT, metadata={op_type="lt" op_name="while/cond/lt"}
}

ENTRY %primitive_computation_while.37 (parameter.1: pred[], parameter.2: pred[], parameter.3: s32[], parameter.4: pred[]) -> (s32[], pred[]) {
  %parameter.1 = pred[] parameter(0), metadata={op_type="while" op_name="while[ body_nconsts=2\n       cond_nconsts=0 ]"}
  %parameter.2 = pred[] parameter(1), metadata={op_type="while" op_name="while[ body_nconsts=2\n       cond_nconsts=0 ]"}
  %parameter.3 = s32[] parameter(2), metadata={op_type="while" op_name="while[ body_nconsts=2\n       cond_nconsts=0 ]"}
  %copy.6 = s32[] copy(s32[] %parameter.3)
  %parameter.4 = pred[] parameter(3), metadata={op_type="while" op_name="while[ body_nconsts=2\n       cond_nconsts=0 ]"}
  %copy.7 = pred[] copy(pred[] %parameter.4)
  %tuple.3 = (s32[], pred[]) tuple(s32[] %copy.6, pred[] %copy.7)
  ROOT %while = (s32[], pred[]) while((s32[], pred[]) %tuple.3), condition=%cond_computation.22.clone, body=%body_computation.6.clone, metadata={op_type="while" op_name="while[ body_nconsts=2\n       cond_nconsts=0 ]"}
}
```

which roughly corresponds to the following Python loop (the `pred[]` element of the carry is simply reset to `false` on every iteration):

```python
while i < 10000:
  i += 1
```
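For context, here is a minimal pure-Python sketch of how a `scan` like the one in the repro can be written as a `while` loop whose carry pairs an integer counter with the user's carry `y`. This is consistent with the `s32[]` counter, the `add` of 1, and the `compare(..., 10000), direction=LT` in the HLO above. The helpers `scan_as_while`, `polyval_via_scan` and `counter_only_loop` are hypothetical, and this is a simplified model, not JAX's actual lowering. Note that only `y` depends on `p` and `x`; the counter does not, so a partial evaluator that treats input-independent values as known can split the counter loop off and run it on its own, ahead of time.

```python
import numpy as np


def scan_as_while(f, init, xs):
    """Hypothetical sketch: lax.scan(f, init, xs) as an explicit while loop.

    The loop state is (i, c): an integer counter i plus the user's carry c.
    Per-step outputs are discarded, as in the repro (f returns None for them).
    """
    i, c = 0, init
    n = len(xs)
    while i < n:              # mirrors compare(i, n), direction=LT
        c, _ = f(c, xs[i])
        i += 1                # mirrors add(i, 1)
    return c


def polyval_via_scan(p, x):
    # Horner's rule with the same body as the repro: y <- y * x + p_k
    return scan_as_while(lambda y, pk: (y * x + pk, None), 0.0, p)


def counter_only_loop(n):
    # Only the counter part of the state: it touches neither p nor x,
    # so its result (i == n) is fixed before any input data arrive.
    i = 0
    while i < n:
        i += 1
    return i
```

With `p = np.array([2.0, -3.0, 1.0])` and `x = 2.0`, `polyval_via_scan(p, x)` returns `3.0`, matching `np.polyval(p, x)`. By contrast, `counter_only_loop(10000)` just returns `10000`: every one of its iterations computes something that was already known statically.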

We believe an upcoming change to how jit computations are staged out of Python will fix this.
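To make the difference concrete, here is a toy model, an illustrative assumption rather than JAX's real tracing machinery, of two strategies: eagerly evaluating input-independent values while tracing, versus staging everything inside a jitted function out to the compiled program. The functions `toy_jit_trace` and `loops_executed` and the `stage_everything` flag are hypothetical.

```python
def toy_jit_trace(n, stage_everything):
    """Hypothetical model of tracing the repro's scan under jit.

    Returns (iterations executed at trace time, loops staged into the
    compiled program). Not JAX's actual partial evaluator.
    """
    trace_time_iters = 0
    staged = []
    if not stage_everything:
        # The counter does not depend on p or x, so it is treated as known
        # and computed eagerly during tracing: the wasted compile-time loop.
        i = 0
        while i < n:
            i += 1
            trace_time_iters += 1
    # y depends on p and x, so a loop (with its own counter) is staged
    # out under either strategy and runs each time the function is called.
    staged.append("while(counter, y)")
    return trace_time_iters, staged


def loops_executed(num_calls, stage_everything, n=10000):
    # Assumes tracing happens once; later calls reuse the compiled program.
    trace_iters, staged = toy_jit_trace(n, stage_everything)
    return (1 if trace_iters else 0) + num_calls * len(staged)
```

With two calls, as in the repro, `loops_executed(2, stage_everything=False)` gives 3, matching the three loops observed in this issue, while `loops_executed(2, stage_everything=True)` gives 2. Staging everything out would remove the extra loop.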

hawkinsp added the bug label on May 15, 2020.

2 participants