Skip to content

Commit

Permalink
implement generator expressions (#4470)
Browse files Browse the repository at this point in the history
This introduces the types `Generator`, which maps a function over
an iterator, and `IteratorND`, which wraps an iterator with a
shape tuple.
  • Loading branch information
JeffBezanson committed Feb 18, 2016
1 parent 8f26020 commit 822a4a0
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 21 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ CORE_SRCS := $(addprefix $(JULIAHOME)/, \
base/dict.jl \
base/error.jl \
base/essentials.jl \
base/generator.jl \
base/expr.jl \
base/functors.jl \
base/hashing.jl \
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Julia v0.5.0 Release Notes
New language features
---------------------

* Generator expressions, e.g. `f(i) for i in 1:n` (#4470). This returns an iterator
that computes the specified values on demand.

* Macro expander functions are now generic, so macros can have multiple definitions
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
However note that the argument types refer to the syntax tree representation, and not
Expand Down
2 changes: 1 addition & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ function map!{F}(f::F, dest::AbstractArray, A::AbstractArray)
return dest
end

function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A::AbstractArray)
function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A)
# map to dest array, checking the type of each result. if a result does not
# match, widen the result type and re-dispatch.
i = offs
Expand Down
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ unsafe_convert{T}(::Type{T}, x::T) = x
(::Type{Array{T}}){T}(m::Int, n::Int, o::Int) = Array{T,3}(m, n, o)

# TODO: possibly turn these into deprecations
Array{T,N}(::Type{T}, d::NTuple{N,Int}) = Array{T}(d)
Array{T}(::Type{T}, d::Int...) = Array{T}(d)
Array{T}(::Type{T}, m::Int) = Array{T,1}(m)
Array{T}(::Type{T}, m::Int,n::Int) = Array{T,2}(m,n)
Expand Down
1 change: 1 addition & 0 deletions base/coreimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ macro doc(str, def) Expr(:escape, def) end

## Load essential files and libraries
include("essentials.jl")
include("generator.jl")
include("reflection.jl")
include("options.jl")

Expand Down
21 changes: 21 additions & 0 deletions base/generator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Generator(f, iter)
Given a function `f` and an iterator `iter`, construct an iterator that yields
the values of `f` applied to the elements of `iter`.
The syntax `f(x) for x in iter` is syntax for constructing an instance of this
type.
"""
immutable Generator{I,F}
f::F
iter::I
end

start(g::Generator) = start(g.iter)
done(g::Generator, s) = done(g.iter, s)
function next(g::Generator, s)
v, s2 = next(g.iter, s)
g.f(v), s2
end

collect(g::Generator) = map(g.f, g.iter)
48 changes: 48 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,51 @@ eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))
x = prod_next(p, st)
((x[1][1],x[1][2]...), x[2])
end

_size(p::Prod2) = (length(p.a), length(p.b))
_size(p::Prod) = (length(p.a), _size(p.b)...)

"""
IteratorND(iter, dims)
Given an iterator `iter` and dimensions tuple `dims`, return an iterator that
yields the same values as `iter`, but with the specified multi-dimensional shape.
For example, this determines the shape of the array returned when `collect` is
applied to this iterator.
"""
immutable IteratorND{I,N}
iter::I
dims::NTuple{N,Int}

function (::Type{IteratorND}){I,N}(iter::I, shape::NTuple{N,Integer})
if length(iter) != prod(shape)
throw(DimensionMismatch("dimensions $shape must be consistent with iterator length $(iter(a))"))
end
new{I,N}(iter, shape)
end
(::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, _size(p))
end

start(i::IteratorND) = start(i.iter)
done(i::IteratorND, s) = done(i.iter, s)
next(i::IteratorND, s) = next(i.iter, s)

size(i::IteratorND) = i.dims
length(i::IteratorND) = prod(size(i))
ndims{I,N}(::IteratorND{I,N}) = N

eltype{I}(::IteratorND{I}) = eltype(I)

collect(i::IteratorND) = copy!(Array(eltype(i),size(i)), i)

function collect{I<:IteratorND}(g::Generator{I})
sz = size(g.iter)
if length(g.iter) == 0
return Array(Union{}, sz)
end
st = start(g)
first, st = next(g, st)
dest = Array(typeof(first), sz)
dest[1] = first
return map_to!(g.f, 2, st, dest, g.iter)
end
1 change: 1 addition & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
include("essentials.jl")
include("docs/bootstrap.jl")
include("base.jl")
include("generator.jl")
include("reflection.jl")
include("options.jl")

Expand Down
31 changes: 31 additions & 0 deletions doc/manual/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,37 @@ that the result is of type ``Float64`` by writing::

Float64[ 0.25*x[i-1] + 0.5*x[i] + 0.25*x[i+1] for i=2:length(x)-1 ]

.. _man-generator-expressions:

Generator Expressions
---------------------

Comprehensions can also be written without the enclosing square brackets, producing
an object known as a generator. This object can be iterated to produce values on
demand, instead of allocating an array and storing them in advance
(see :ref:`_man-interfaces-iteration`).
For example, the following expression sums a series without allocating memory::

julia> sum(1/n^2 for n=1:1000)
1.6439345666815615

When writing a generator expression with multiple dimensions, it needs to be
enclosed in parentheses to avoid ambiguity::

julia> collect(1/(i+j) for i=1:2, j=1:2)
ERROR: function collect does not accept keyword arguments

In this call, the range ``j=1:2`` was interpreted as a second argument to
``collect``. This is fixed by adding parentheses::

julia> collect((1/(i+j) for i=1:2, j=1:2))
2x2 Array{Float64,2}:
0.5 0.333333
0.333333 0.25

Note that ``collect`` gathers the values produced by an iterator into an array,
giving the same effect as an array comprehension.

.. _man-array-indexing:

Indexing
Expand Down
53 changes: 33 additions & 20 deletions src/julia-parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1401,21 +1401,22 @@
(parse-comma-separated s parse-eq*))

;; as above, but allows both "i=r" and "i in r"
(define (parse-iteration-spec s)
(let ((r (parse-eq* s)))
(cond ((and (pair? r) (eq? (car r) '=)) r)
((eq? r ':) r)
((and (length= r 4) (eq? (car r) 'comparison)
(or (eq? (caddr r) 'in) (eq? (caddr r) '∈)))
`(= ,(cadr r) ,(cadddr r)))
(else
(error "invalid iteration specification")))))

(define (parse-comma-separated-iters s)
(let loop ((ranges '()))
(let ((r (parse-eq* s)))
(let ((r (cond ((and (pair? r) (eq? (car r) '=))
r)
((eq? r ':)
r)
((and (length= r 4) (eq? (car r) 'comparison)
(or (eq? (caddr r) 'in) (eq? (caddr r) '∈)))
`(= ,(cadr r) ,(cadddr r)))
(else
(error "invalid iteration specification")))))
(case (peek-token s)
((#\,) (take-token s) (loop (cons r ranges)))
(else (reverse! (cons r ranges))))))))
(let ((r (parse-iteration-spec s)))
(case (peek-token s)
((#\,) (take-token s) (loop (cons r ranges)))
(else (reverse! (cons r ranges)))))))

(define (parse-space-separated-exprs s)
(with-space-sensitive
Expand Down Expand Up @@ -1471,6 +1472,12 @@
(loop (cons nxt lst)))
((eqv? c #\;) (loop (cons nxt lst)))
((eqv? c closer) (loop (cons nxt lst)))
((eq? c 'for)
(take-token s)
(let ((gen (parse-generator s nxt #f)))
(if (eqv? (require-token s) #\,)
(take-token s))
(loop (cons gen lst))))
;; newline character isn't detectable here
#;((eqv? c #\newline)
(error "unexpected line break in argument list"))
Expand Down Expand Up @@ -1515,7 +1522,7 @@
(define (parse-comprehension s first closer)
(let ((r (parse-comma-separated-iters s)))
(if (not (eqv? (require-token s) closer))
(error (string "expected " closer))
(error (string "expected \"" closer "\""))
(take-token s))
`(comprehension ,first ,@r)))

Expand All @@ -1525,12 +1532,11 @@
`(dict_comprehension ,@(cdr c))
(error "invalid dict comprehension"))))

(define (parse-generator s first closer)
(let ((r (parse-comma-separated-iters s)))
(if (not (eqv? (require-token s) closer))
(error (string "expected " closer))
(take-token s))
`(macrocall @generator ,first ,@r)))
(define (parse-generator s first allow-comma)
(let ((r (if allow-comma
(parse-comma-separated-iters s)
(list (parse-iteration-spec s)))))
`(generator ,first ,@r)))

(define (parse-matrix s first closer gotnewline)
(define (fix head v) (cons head (reverse v)))
Expand Down Expand Up @@ -1960,6 +1966,13 @@
`(tuple ,ex)
;; value in parentheses (x)
ex))
((eq? t 'for)
(take-token s)
(let ((gen (parse-generator s ex #t)))
(if (eqv? (require-token s) #\) )
(take-token s)
(error "expected \")\""))
gen))
(else
;; tuple (x,) (x,y) (x...) etc.
(if (eqv? t #\, )
Expand Down
17 changes: 17 additions & 0 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,23 @@
(lower-ccall name RT (cdr argtypes) args))))
e))

'generator
(lambda (e)
(let ((expr (cadr e))
(vars (map cadr (cddr e)))
(ranges (map caddr (cddr e))))
(let* ((argname (if (and (length= vars 1) (symbol? (car vars)))
(car vars)
(gensy)))
(splat (if (eq? argname (car vars))
'()
`((= (tuple ,@vars) ,argname)))))
(expand-forms
`(call (top Generator) (-> ,argname (block ,@splat ,expr))
,(if (length= ranges 1)
(car ranges)
`(call (top IteratorND) (call (top product) ,@ranges))))))))

'comprehension
(lambda (e)
(expand-forms (lower-comprehension #f (cadr e) (cddr e))))
Expand Down
20 changes: 20 additions & 0 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,23 @@ let
foreach((args...)->push!(a,args), [2,4,6], [10,20,30])
@test a == [(2,10),(4,20),(6,30)]
end

# generators (#4470, #14848)

@test sum(i/2 for i=1:2) == 1.5
@test collect(2i for i=2:5) == [4,6,8,10]
@test collect((i+10j for i=1:2,j=3:4)) == [31 41; 32 42]
@test collect((i+10j for i=1:2,j=3:4,k=1:1)) == reshape([31 41; 32 42], (2,2,1))

let I = Base.IteratorND(1:27,(3,3,3))
@test collect(I) == reshape(1:27,(3,3,3))
@test size(I) == (3,3,3)
@test length(I) == 27
@test eltype(I) === Int
@test ndims(I) == 3
end

let A = collect(Base.Generator(x->2x, Real[1.5,2.5]))
@test A == [3,5]
@test isa(A,Vector{Float64})
end

0 comments on commit 822a4a0

Please sign in to comment.