How to use model parameters trained in Flux.jl in a SimpleChains.jl model? #159

Open
manozzing opened this issue Jan 26, 2024 · 0 comments


Hello,

I'm quite new to this package (I think it is really cool!) and am wondering whether I can copy Flux model parameters into a SimpleChains model. For reference, I'm developing a machine-learned advection solver that uses a neural network to estimate the numerical flux at each time step. Here is the code I'm working on, as an example:

Flux.jl model

flux_Flux = Chain(
    Flux.Conv((3, 1), 2 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 2, Flux.identity; pad = 0),
)
loss(x, y) = Flux.Losses.mae(flux_Flux(x), y)
ps = Flux.params(flux_Flux)

SimpleChains.jl model

flux_estimator = SimpleChain(
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.identity, (3, 1), 2)
)
p = SimpleChains.init_params(flux_estimator, size(input))

I confirmed that both models have the same structure and the same number of parameters.

Chain(
  Conv((3, 1), 2 => 10, relu),          # 70 parameters
  Conv((3, 1), 10 => 10, relu),         # 310 parameters
  Conv((3, 1), 10 => 2),                # 62 parameters
)                   # Total: 6 arrays, 442 parameters, 2.781 KiB.
442-element StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}}:
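
The per-layer counts above follow from kernel size times input channels times output channels, plus one bias per output channel. A minimal Base-only sketch of that arithmetic is below; the helper `conv_param_count` is a hypothetical name introduced for illustration and is not part of Flux.jl or SimpleChains.jl.

```julia
# Hypothetical helper (not a Flux.jl or SimpleChains.jl function):
# trainable parameters in a convolution layer that has one bias per output channel.
conv_param_count(kernel::NTuple{N,Int}, cin::Int, cout::Int) where {N} =
    prod(kernel) * cin * cout + cout

# (kernel size, input channels, output channels) for the three layers above.
layers = [((3, 1), 2, 10), ((3, 1), 10, 10), ((3, 1), 10, 2)]
counts = [conv_param_count(k, cin, cout) for (k, cin, cout) in layers]

println(counts)       # [70, 310, 62]
println(sum(counts))  # 442
```

Matching totals only show that the two parameter vectors have the same length. They do not show that the weights are stored in the same order or used by the same operation.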

After training the Flux model, I copied its parameters into the SimpleChains parameter vector as follows:

p[1:60] = Flux.params(flux_Flux)[1][1:60]
p[61:70] = Flux.params(flux_Flux)[2][1:10]
p[71:370] = Flux.params(flux_Flux)[3][1:300]
p[371:380] = Flux.params(flux_Flux)[4][1:10]
p[381:440] = Flux.params(flux_Flux)[5][1:60]
p[441:442] = Flux.params(flux_Flux)[6][1:2]
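
One possible reason an in-order copy like this could fail, not verified here for SimpleChains.jl: Flux's `Conv` computes a true convolution by default, meaning the kernel is flipped relative to cross-correlation. If the receiving library computes cross-correlation, or stores its weights in a different axis order, the same flat values would produce different outputs. The Base-only sketch below illustrates the flip in 1-D; `xcorr` and `trueconv` are hypothetical helper names, not library functions.

```julia
# Hypothetical helpers for illustration only (Base Julia, no packages).

# Valid-mode cross-correlation of a 1-D signal with a kernel (no padding).
xcorr(x, w) = [sum(x[i + k - 1] * w[k] for k in eachindex(w))
               for i in 1:(length(x) - length(w) + 1)]

# A true convolution is a cross-correlation with the kernel reversed.
trueconv(x, w) = xcorr(x, reverse(w))

x = Float32[1, 2, 4, 8, 16]
w = Float32[1, 0, -1]   # asymmetric kernel, so the flip is visible

println(xcorr(x, w))                             # Float32[-3.0, -6.0, -12.0]
println(trueconv(x, w))                          # Float32[3.0, 6.0, 12.0]
println(trueconv(x, reverse(w)) == xcorr(x, w))  # true
```

If this is the cause, reversing each Flux weight array along its kernel dimensions before copying (for example `reverse(W; dims = (1, 2))` for a 4-D weight `W`) would be the analogous fix. Whether SimpleChains.jl expects that, or a different channel-axis order, is an assumption to check against its documentation.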

However, the two models give totally different results when I feed them the same input dataset.

Flux estimates over several time steps from the Flux.jl model:
[Image: flux_scattered_Flux]

Flux estimates over several time steps from the SimpleChains.jl model, using the same parameters as the Flux.jl model:
[Image: flux_scattered_SC]

Do you have any idea why the SimpleChains.jl model gives such different results? I originally tried to train my model with SimpleChains.jl, but that training was not successful either, so I chose to pass in the parameters from Flux.jl instead, and so far that has not helped. Any comments would help me out. Thank you so much!
