Type Constraints in the Rule Structs #205

Open
avik-pal opened this issue Dec 23, 2024 · 5 comments

@avik-pal
Member

I was working on getting Optimisers to work with Reactant, and it (mostly) does, but one current issue is that `eta` is forced to be `Float64` in some of the rule structs.

Consider the IR generated for an update:

julia> pss = (; a = (rand(3) |> Reactant.to_rarray))
(a = ConcreteRArray{Float64, 1}([0.023549009580651203, 0.10813549621409191, 0.7874517465499301]),)

julia> opt = Optimisers.Adam()  # default eta = 0.001
Adam(0.001, (0.9, 0.999), 1.0e-8)

julia> st_opt = @allowscalar Optimisers.setup(opt, pss)
(a = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (ConcreteRArray{Float64, 1}([0.0, 0.0, 0.0]), ConcreteRArray{Float64, 1}([0.0, 0.0, 0.0]), (0.9, 0.999))),)

julia> @code_hlo Optimisers.update(st_opt, pss, pss)
module {
  func.func @main(%arg0: tensor<3xf64>, %arg1: tensor<3xf64>, %arg2: tensor<3xf64>) -> (tensor<3xf64>, tensor<3xf64>, tensor<3xf64>) {
    %cst = stablehlo.constant dense<1.000000e-03> : tensor<3xf64>
    %cst_0 = stablehlo.constant dense<1.000000e-08> : tensor<3xf64>
    %cst_1 = stablehlo.constant dense<0.0010000000000000009> : tensor<3xf64>
    %cst_2 = stablehlo.constant dense<0.99899999999999999> : tensor<3xf64>
    %cst_3 = stablehlo.constant dense<0.099999999999999978> : tensor<3xf64>
    %cst_4 = stablehlo.constant dense<9.000000e-01> : tensor<3xf64>
    %0 = stablehlo.multiply %cst_4, %arg0 : tensor<3xf64>
    %1 = stablehlo.multiply %cst_3, %arg2 : tensor<3xf64>
    %2 = stablehlo.add %0, %1 : tensor<3xf64>
    %3 = stablehlo.multiply %cst_2, %arg1 : tensor<3xf64>
    %4 = stablehlo.abs %arg2 : tensor<3xf64>
    %5 = stablehlo.multiply %4, %4 : tensor<3xf64>
    %6 = stablehlo.multiply %cst_1, %5 : tensor<3xf64>
    %7 = stablehlo.add %3, %6 : tensor<3xf64>
    %8 = stablehlo.divide %2, %cst_3 : tensor<3xf64>
    %9 = stablehlo.divide %7, %cst_1 : tensor<3xf64>
    %10 = stablehlo.sqrt %9 : tensor<3xf64>
    %11 = stablehlo.add %10, %cst_0 : tensor<3xf64>
    %12 = stablehlo.divide %8, %11 : tensor<3xf64>
    %13 = stablehlo.multiply %12, %cst : tensor<3xf64>
    %14 = stablehlo.subtract %arg2, %13 : tensor<3xf64>
    return %2, %7, %14 : tensor<3xf64>, tensor<3xf64>, tensor<3xf64>
  }
}

While this looks correct, look closer at the constants:

%cst_1 = stablehlo.constant dense<0.0010000000000000009> : tensor<3xf64>
%cst_2 = stablehlo.constant dense<0.99899999999999999> : tensor<3xf64>
%cst_3 = stablehlo.constant dense<0.099999999999999978> : tensor<3xf64>

The learning rate and the other hyperparameters are embedded into the IR as constants. So even if we later `Optimisers.adjust` the learning rate, the compiled function will still use the old value.
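To make the underlying issue concrete, here is a minimal sketch in plain Julia (no Reactant required; `FixedRule` and `FlexibleRule` are hypothetical names, not Optimisers API). A concretely-typed field converts whatever it is given on construction, which is what forces a traced value to be materialized as a plain number and baked into the IR; a parametric field would preserve the caller's type instead:

```julia
# A concretely-typed field converts its input on construction:
struct FixedRule
    eta::Float64              # concrete field: input is converted to Float64
end

# A parametric field preserves the input's type:
struct FlexibleRule{T<:Real}
    eta::T                    # parametric field: original type is kept
end

println(typeof(FixedRule(0.1f0).eta))     # Float64 -- the Float32 was converted
println(typeof(FlexibleRule(0.1f0).eta))  # Float32 -- preserved as-is
```

The same conversion that turns a `Float32` into a `Float64` here is what would turn a traced learning rate into a compile-time constant.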

@ToucheSir
Member

Although I don't remember the specifics of that discussion, the choice to standardize eltypes was discussed at length over multiple issues. #151 is a good jumping-off point with links.

@CarloLucibello
Member

CarloLucibello commented Jan 6, 2025

@mcabbott what do you think we should do here?

Now we have

"""
    @def struct Rule; eta = 0.1; beta = (0.7, 0.8); end

Helper macro for defining rules with default values.
The types of the literal values are used in the `struct`,
like this:

struct Rule
  eta::Float64
  beta::Tuple{Float64, Float64}
  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end

Any field called `eta` is assumed to be a learning rate, and cannot be negative.
"""

Should it become

struct Rule{T1,T2}
  eta::T1
  beta::T2
  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end

? (And possibly add a nice `show` method that avoids printing the type parameters.)

@wsmoses
Contributor

wsmoses commented Jan 9, 2025

That looks good to me, if @mcabbott is game you should go for it!

Maybe it's worth making a draft PR and seeing what breaks?

@mcabbott
Member

mcabbott commented Jan 9, 2025

We dumped type parameters partly because it was annoying that Adam(0.1f0) != Adam(0.1) despite there being no difference in intent. If we re-introduce them, IMO we should do so in a way that preserves this, perhaps by having the constructor allow weird types (like tracked ones) but not Float32?

The sketch above doesn't run, but it looks like it wants to allow any different type for β... perhaps the minimal thing is just one type parameter `{T <: Real}` for all the numbers?
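A runnable sketch of that single-parameter idea, under the assumptions above (this is illustrative, not a proposed final design): the constructor normalizes plain `Float32`/`Float16` inputs to `Float64`, so there is no spurious `Rule{Float32}` vs `Rule{Float64}` distinction, while other `Real` subtypes (e.g. tracked numbers) would pass through unchanged:

```julia
struct Rule{T<:Real}
    eta::T
    beta::Tuple{Float64,Float64}
    function Rule(eta::Real, beta = (0.7, 0.8))
        eta < 0 && error("eta must be non-negative")
        # normalize ordinary reduced-precision floats, so Rule(0.1f0)
        # and Rule(0.1) produce the same concrete type...
        eta isa Union{Float16,Float32} && (eta = Float64(eta))
        # ...while any other Real subtype keeps its own type parameter
        new{typeof(eta)}(eta, beta)
    end
end
Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)

println(typeof(Rule(0.1f0)))  # Rule{Float64} -- same type as Rule(0.1)
println(typeof(Rule(0.1)))    # Rule{Float64}
```

Whether the normalization should be by type (as here) or by some trait would need deciding; the point is only that type parameters and "no difference in intent" are not mutually exclusive.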

@wsmoses
Contributor

wsmoses commented Jan 9, 2025

So T <: Real will mess with us a bit, since we have

abstract type RNumber{T<:ReactantPrimitive} <: Number end
mutable struct TracedRNumber{T} <: RNumber{T}

Unfortunately I don't know a good way in Julia to define `Traced{T} <: Real` iff `T <: Real`.
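That limitation is real: a struct's supertype is fixed at definition time and cannot be conditional on its parameter. One common workaround is to dispatch on a `Union` constraint instead of requiring `<: Real`. A sketch, with illustrative stand-ins (these are not Reactant's actual definitions, and `MaybeTracedReal`/`TracedAwareRule` are hypothetical names):

```julia
# Simplified mirror of the Reactant-style hierarchy quoted above:
abstract type RNumber{T} <: Number end
mutable struct TracedRNumber{T} <: RNumber{T}
    value::T
end

# Accept either a plain real or a traced real, without conditional subtyping:
const MaybeTracedReal = Union{Real, TracedRNumber{<:Real}}

struct TracedAwareRule{T<:MaybeTracedReal}
    eta::T
end

println(typeof(TracedAwareRule(0.1)))                          # TracedAwareRule{Float64}
println(typeof(TracedAwareRule(TracedRNumber{Float64}(0.1))))  # traced variant also accepted
```

The cost is that the `Union` must be maintained by hand (or via a package-extension hook) rather than falling out of the type hierarchy.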
