Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance stablehlo-quant-legalize-to-tosa-rescale #2534

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,94 @@ func.func @sub(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
-> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
}

// -----
// CHECK-LABEL: @mul
func.func @mul(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>> {
// CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>}
// CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>}
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V0]], %[[V1]] : tensor<2x2xi32>
// CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array<i32: 1717986918>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 37>}
// CHECK: return %[[V3]] : tensor<2x2x!quant.uniform<i8:f32, 1.500000e-01:-1>>
%0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<2x2x!quant.uniform<i8:f32, 0.075:-1>>)
-> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-1>>
}

// -----
// CHECK-LABEL: @div
func.func @div(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>> {
// CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>}
// CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>}
// CHECK: %[[V2:.+]] = stablehlo.divide %[[V0]], %[[V1]] : tensor<2x2xi32>
// CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array<i32: 1717986918>, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array<i8: 37>}
// CHECK: return %[[V3]] : tensor<2x2x!quant.uniform<i8:f32, 1.500000e-01:-3>>
%0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>)
-> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
}

// -----
// CHECK-LABEL: @max
func.func @max(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>> {
// CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 12>}
// CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 10>}
// CHECK: %[[V2:.+]] = stablehlo.maximum %[[V0]], %[[V1]] : tensor<2x2xi32>
// CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array<i32: 1073741824>, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array<i8: 51>}
// CHECK: return %[[V3]] : tensor<2x2x!quant.uniform<i8:f32, 1.500000e-01:-3>>
%0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>)
-> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
}

// -----
// CHECK-LABEL: @min
func.func @min(%arg0 : tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>) -> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>> {
// CHECK-DAG: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 12>}
// CHECK-DAG: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 10>}
// CHECK: %[[V2:.+]] = stablehlo.minimum %[[V0]], %[[V1]] : tensor<2x2xi32>
// CHECK: %[[V3:.+]] = tosa.rescale %[[V2]] {double_round = false, input_zp = 0 : i32, multiplier = array<i32: 1073741824>, output_zp = -3 : i32, per_channel = false, scale32 = true, shift = array<i8: 51>}
// CHECK: return %[[V3]] : tensor<2x2x!quant.uniform<i8:f32, 1.500000e-01:-3>>
%0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<2x2x!quant.uniform<i8:f32, 0.025:-1>>, tensor<2x2x!quant.uniform<i8:f32, 0.075:-2>>)
-> tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 1.5e-01:-3>>
}

// -----
// CHECK-LABEL: @abs
func.func @abs(%arg0 : tensor<20x20x!quant.uniform<i8:f32, 0.025:-1>>) -> tensor<20x20x!quant.uniform<i8:f32, 1.5e-01:-128>> {
// CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>}
// CHECK: %[[V1:.+]] = stablehlo.abs %[[V0]] : tensor<20x20xi32>
// CHECK: %[[V3:.+]] = tosa.rescale %[[V1]] {double_round = false, input_zp = 0 : i32, multiplier = array<i32: 1431655765>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i8: 33>}
// CHECK: return %[[V3]] : tensor<20x20x!quant.uniform<i8:f32, 1.500000e-01:-128>>
%0 = "stablehlo.abs"(%arg0) : (tensor<20x20x!quant.uniform<i8:f32, 0.025:-1>>) -> tensor<20x20x!quant.uniform<i8:f32, 1.5e-01:-128>>
return %0 : tensor<20x20x!quant.uniform<i8:f32, 1.5e-01:-128>>
}

// -----
// CHECK-LABEL: @compareGE
func.func @compareGE(%arg0 : tensor<20x20x!quant.uniform<i8:f32, 0.025:-1>>,
%arg1 : tensor<20x20x!quant.uniform<i8:f32, 0.075:-2>>) -> tensor<20x20xi1> {
// CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 12>}
// CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 10>}
// CHECK: %[[V2:.+]] = stablehlo.compare GE, %[[V0]], %[[V1]], TOTALORDER :
// CHECK: return %[[V2]]
%0 = stablehlo.compare GE, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform<i8:f32, 0.025:-1>>, tensor<20x20x!quant.uniform<i8:f32, 0.075:-2>>) -> tensor<20x20xi1>
return %0 : tensor<20x20xi1>
}

// -----
// CHECK-LABEL: @compareLT
func.func @compareLT(%arg0 : tensor<20x20x!quant.uniform<i16:f32, 0.025:-1>>,
%arg1 : tensor<20x20x!quant.uniform<i16:f32, 0.075:-2>>) -> tensor<20x20xi1> {
// CHECK: %[[V0:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array<i32: 1431655765>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 17>}
// CHECK: %[[V1:.+]] = tosa.rescale %arg1 {double_round = false, input_zp = -2 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 15>}
// CHECK: %[[V2:.+]] = stablehlo.compare LT, %[[V0]], %[[V1]], TOTALORDER :
// CHECK: return %[[V2]]
%0 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<20x20x!quant.uniform<i16:f32, 0.025:-1>>, tensor<20x20x!quant.uniform<i16:f32, 0.075:-2>>) -> tensor<20x20xi1>
return %0 : tensor<20x20xi1>
}
Loading
Loading