Skip to content

Commit

Permalink
update flux compat for test/
Browse files Browse the repository at this point in the history
  • Loading branch information
zuhengxu committed Aug 22, 2023
1 parent 0ced915 commit 45c8e0a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
23 changes: 23 additions & 0 deletions example/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using CUDA
using LinearAlgebra
using Distributions, Random
using Bijectors
using Flux
import NormalizingFlows as NF

CUDA.functional()
rng = CUDA.default_rng()
T = Float32
q0_g = MvNormal(CUDA.zeros(T, 2), I)
# construct gpu flow
ts = reduce(, [f32(Bijectors.PlanarLayer(2)) for _ in 1:2])
ts_g = gpu(ts)
flow_g = transformed(q0_g, ts_g)

# sample from GPU MvNormal
x = NF.rand_device(rng, q0_g) # good
xs = NF.rand_device(rng, q0_g, 100) # ambiguous

# sample from GPU flow
y = NF.rand_device(rng, flow_g) # ambiguous
ys = NF.rand_device(rng, flow_g, 100) # ambiguous
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ ForwardDiff = "0.10.25"
Optimisers = "0.2.16"
ReverseDiff = "1.14"
Zygote = "0.6"
Flux = "0.14"
Flux = "0.13, 0.14"

0 comments on commit 45c8e0a

Please sign in to comment.