Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy sublanguage #1668

Merged
merged 1 commit into from
Jan 8, 2020
Merged

Lazy sublanguage #1668

merged 1 commit into from
Jan 8, 2020

Commits on Jan 8, 2020

  1. implement lazy sublanguage

    Before this commit, this computation would avoid materializing the iota
    array at trace time:
    
      @jit
      def f(x):
        m, n = x.shape
        return x + np.arange(n)
    
    But this one would materialize the iota array at trace time and stage it
    into the computation as a potentially large array constant:
    
      @jit
      def f(x):
        m, n = x.shape
        return x + np.arange(m)[:, None]
    
    The difference is that previously operations like broadcasts,
    transposes, and reshapes that add singleton dimensions (as above) would
    force otherwise lazy values to be materialized, while after this commit
    broadcasts, transposes, and reshapes are all lazy operations that only
    update metadata on their input rather than compiling and executing XLA
    computations and producing new buffers.
    
    Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).
    
    This commit replaces the ad-hoc "lazy device constant" system, which was
    used to get the simpler behavior in the first example above.
    
    Incidentally fixes #1431
    
    See #1668 for more.
    mattjj committed Jan 8, 2020
    Configuration menu
    Copy the full SHA
    cd54c67 View commit details
    Browse the repository at this point in the history