Skip to content

Commit

Permalink
Cleanup; add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanfilyonok committed Dec 23, 2021
1 parent c1bcea8 commit 7aea127
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 292 deletions.
54 changes: 0 additions & 54 deletions src/Brahma.FSharp.OpenCL.Core/ClContext.fs
Original file line number Diff line number Diff line change
Expand Up @@ -148,57 +148,3 @@ type ClContext private (context: Context, device: Device, translator: FSQuotatio
AllocationMode = allocationMode
}
)

// type RuntimeContext =
// {
// ClContext: ClContext
// CommandQueue: MailboxProcessor<Msg>
// PreferedWGSize: int
// }

// member this.WithNewCommandQueue() =
// { this with CommandQueue = CommandQueueProvider.CreateQueue(this.ClContext) }

// member this.CreateClBuffer
// (
// data: 'a[],
// ?hostAccessMode: HostAccessMode,
// ?deviceAccessMode: DeviceAccessMode,
// ?allocationMode: AllocationMode
// ) =

// let hostAccessMode = defaultArg hostAccessMode ClMemFlags.DefaultIfData.HostAccessMode
// let deviceAccessMode = defaultArg deviceAccessMode ClMemFlags.DefaultIfData.DeviceAccessMode
// let allocationMode = defaultArg allocationMode ClMemFlags.DefaultIfData.AllocationMode

// new ClBuffer<'a>(
// this.ClContext,
// Data data,
// {
// HostAccessMode = hostAccessMode
// DeviceAccessMode = deviceAccessMode
// AllocationMode = allocationMode
// }
// )

// member this.CreateClBuffer
// (
// size: int,
// ?hostAccessMode: HostAccessMode,
// ?deviceAccessMode: DeviceAccessMode,
// ?allocationMode: AllocationMode
// ) =

// let hostAccessMode = defaultArg hostAccessMode ClMemFlags.DefaultIfNoData.HostAccessMode
// let deviceAccessMode = defaultArg deviceAccessMode ClMemFlags.DefaultIfNoData.DeviceAccessMode
// let allocationMode = defaultArg allocationMode ClMemFlags.DefaultIfNoData.AllocationMode

// new ClBuffer<'a>(
// this.ClContext,
// Size size,
// {
// HostAccessMode = hostAccessMode
// DeviceAccessMode = deviceAccessMode
// AllocationMode = allocationMode
// }
// )
2 changes: 1 addition & 1 deletion src/Brahma.FSharp.OpenCL.Core/ClProgram.fs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type ClProgram<'TRange, 'a when 'TRange :> INDRange>

member this.Code = clCode

member this.NewKernel(?kernelName) =
member this.GetKernel(?kernelName) =
let kernelName = defaultArg kernelName "brahmaKernel"
let kernel = createKernel program kernelName

Expand Down
2 changes: 1 addition & 1 deletion src/Brahma.FSharp.OpenCL.Core/ClTask.fs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ module ClTaskOpened =
opencl {
let! ctx = ClTask.ask

let kernel = ctx.CreateClProgram(command).NewKernel()
let kernel = ctx.CreateClProgram(command).GetKernel()

ctx.CommandQueue.Post <| MsgSetArguments(fun () -> binder kernel.KernelFunc)
ctx.CommandQueue.Post <| Msg.CreateRunMsg<_, _>(kernel)
Expand Down
2 changes: 1 addition & 1 deletion src/Brahma.FSharp.OpenCL.Core/IKernel.fs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ open OpenCL.Net
type IKernel<'TRange, 'a when 'TRange :> INDRange> =
abstract Kernel : Kernel
abstract NDRange : INDRange
// dont sure about naming
// not sure about naming
abstract KernelFunc : ('TRange -> 'a)
abstract ReleaseInternalBuffers : unit -> unit
4 changes: 2 additions & 2 deletions src/Brahma.FSharp.OpenCL.Shared/KernelLangExtensions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ module KernelLangExtensions =
failIfOutsideKernel ()
ignore null

let barrierGlobal =
let barrierGlobal () =
failIfOutsideKernel ()
ignore null

let barrierFull =
let barrierFull () =
failIfOutsideKernel ()
ignore null

Expand Down
84 changes: 22 additions & 62 deletions src/Brahma.FSharp.OpenCL.Translator/Body.fs
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,7 @@ module rec Body =
do! State.modify (fun context -> context.VarDecls.Clear(); context)

for expr in linearized do
// NOTE с тим не работает, хотя до этого работало
// непонятно зачем это вообще нужно
// NOTE тут что то сломалось :(
// do! State.modify (fun context -> context.VarDecls.Clear(); context)
match! translate expr with
| :? StatementBlock<Lang> as s1 ->
Expand Down Expand Up @@ -522,9 +521,9 @@ module rec Body =

let! unionValueExpr = translateAsExpr expr

// NOTE для опшна классы наследники не создаются, поэтому не работает
let caseName = propInfo.DeclaringType.Name
let unionCaseField =
// для option классы наследники не создаются, поэтому нужно обрабатывать отдельно
if caseName <> "FSharpOption`1" then
unionType.GetCaseByName caseName
else
Expand Down Expand Up @@ -678,9 +677,7 @@ module rec Body =
)
)
) ->

let! r = translateForIntegerRangeLoop loopVar start finish (Some step) loopBody
return r :> Node<_>
return! translateForIntegerRangeLoop loopVar start finish (Some step) loopBody >>= toNode

| Patterns.Let
(
Expand All @@ -706,41 +703,24 @@ module rec Body =
)
)
) ->
return! translateForIntegerRangeLoop loopVar start finish None loopBody >>= toNode

let! r = translateForIntegerRangeLoop loopVar start finish (None) loopBody
return r :> Node<_>

// | DerivedPatterns.SpecificCall <@ unbox @>
// (
// _,
// _,
// [Patterns.Value (boxed, type')]
// ) ->
// let! r = translateValue (Expr.Value <| unbox boxed) type'
// return r :> Node<_>

| Patterns.Call (exprOpt, mInfo, args) ->
let! r = translateCall exprOpt mInfo args
return r :> Node<_>
| Patterns.Call (exprOpt, mInfo, args) -> return! translateCall exprOpt mInfo args >>= toNode
| Patterns.Coerce (expr, sType) -> return raise <| InvalidKernelException(sprintf "Coerce is not suported: %O" expr)
| Patterns.DefaultValue sType -> return raise <| InvalidKernelException(sprintf "DefaulValue is not suported: %O" expr)

| Patterns.FieldGet (exprOpt, fldInfo) ->
match exprOpt with
| Some expr -> return! translateStructFieldGet expr fldInfo.Name >>= toNode
| None -> return raise <| InvalidKernelException(sprintf "FieldGet for empty host is not suported. Field: %A" fldInfo.Name)

| Patterns.FieldSet (exprOpt, fldInfo, expr) ->
match exprOpt with
| Some e ->
let! r = translateFieldSet e fldInfo.Name expr
return r :> Node<_>
| Some e -> return! translateFieldSet e fldInfo.Name expr >>= toNode
| None -> return raise <| InvalidKernelException(sprintf "Fileld set with empty host is not supported. Field: %A" fldInfo)
| Patterns.ForIntegerRangeLoop (i, from', to', body) ->
let! r = translateForIntegerRangeLoop i from' to' None body
return r :> Node<_>
| Patterns.IfThenElse (cond, thenExpr, elseExpr) ->
return!
translateIf cond thenExpr elseExpr
|> State.map (fun x -> x :> Node<_>)

| Patterns.ForIntegerRangeLoop (i, from', to', body) -> return! translateForIntegerRangeLoop i from' to' None body >>= toNode
| Patterns.IfThenElse (cond, thenExpr, elseExpr) -> return! translateIf cond thenExpr elseExpr >>= toNode

| Patterns.Lambda (var, _expr) -> return raise <| InvalidKernelException(sprintf "Lambda is not suported: %A" expr)
| Patterns.Let (var, expr, inExpr) ->
Expand Down Expand Up @@ -792,20 +772,12 @@ module rec Body =

return NewStruct(unionInfo, tag :: args) :> Node<_>

| Patterns.PropertyGet (exprOpt, propInfo, exprs) ->
let! res = translatePropGet exprOpt propInfo exprs
return res :> Node<_>
| Patterns.PropertySet (exprOpt, propInfo, exprs, expr) ->
let! res = translatePropSet exprOpt propInfo exprs expr
return res :> Node<_>
| Patterns.Sequential (expr1, expr2) ->
let! res = translateSeq expr1 expr2
return res :> Node<_>
| Patterns.PropertyGet (exprOpt, propInfo, exprs) -> return! translatePropGet exprOpt propInfo exprs >>= toNode
| Patterns.PropertySet (exprOpt, propInfo, exprs, expr) -> return! translatePropSet exprOpt propInfo exprs expr >>= toNode
| Patterns.Sequential (expr1, expr2) -> return! translateSeq expr1 expr2 >>= toNode
| Patterns.TryFinally (tryExpr, finallyExpr) -> return raise <| InvalidKernelException(sprintf "TryFinally is not suported: %O" expr)
| Patterns.TryWith (expr1, var1, expr2, var2, expr3) -> return raise <| InvalidKernelException(sprintf "TryWith is not suported: %O" expr)
| Patterns.TupleGet (expr, i) ->
let! r = translateStructFieldGet expr ("_" + (string (i + 1)))
return r :> Node<_>
| Patterns.TupleGet (expr, i) -> return! translateStructFieldGet expr ("_" + (string (i + 1))) >>= toNode
| Patterns.TypeTest (expr, sType) -> return raise <| InvalidKernelException(sprintf "TypeTest is not suported: %O" expr)

| Patterns.UnionCaseTest (expr, unionCaseInfo) ->
Expand All @@ -829,25 +801,13 @@ module rec Body =
VarDecl(res.Type, name, Some(res :> Expression<_>), AddressSpaceQualifier.Constant)
)
let var = Var(name, sType)
let! res = translateVar var
return res :> Node<_>
return! translateVar var >>= toNode
else
let! res = translateValue obj' sType
return res :> Node<_>
| Patterns.Value (obj', sType) ->
// если оборачиваем бокс значение в валью, то тип будет obj
// вот если у нас <@ unbox<int> (box 6) @> то тип такого выражения int, хотя box 6 все еще 6, так что хзкак это делать
//printfn "%A" sType
let! res = translateValue obj' sType
return res :> Node<_>
| Patterns.Var var ->
let! res = translateVar var
return res :> Node<_>
| Patterns.VarSet (var, expr) ->
let! res = translateVarSet var expr
return res :> Node<_>
| Patterns.WhileLoop (condExpr, bodyExpr) ->
let! r = translateWhileLoop condExpr bodyExpr
return r :> Node<_>
return! translateValue obj' sType >>= toNode

| Patterns.Value (obj', sType) -> return! translateValue obj' sType >>= toNode
| Patterns.Var var -> return! translateVar var >>= toNode
| Patterns.VarSet (var, expr) -> return! translateVarSet var expr >>= toNode
| Patterns.WhileLoop (condExpr, bodyExpr) -> return! translateWhileLoop condExpr bodyExpr >>= toNode
| _ -> return raise <| InvalidKernelException(sprintf "Folowing expression inside kernel is not supported:\n%O" expr)
}
7 changes: 6 additions & 1 deletion src/Brahma.FSharp.OpenCL.Translator/Exceptions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ type InvalidKernelException =
new(message: string, inner: Exception) = { inherit Exception(message, inner) } //

/// The exception that is thrown when the unexpected error occured during the translation.
exception TranslationFailedException of string
type TranslationFailedException =
inherit Exception

new() = { inherit Exception() } //
new(message: string) = { inherit Exception(message) }
new(message: string, inner: Exception) = { inherit Exception(message, inner) }
4 changes: 2 additions & 2 deletions src/Brahma.FSharp.OpenCL.Translator/Translator.fs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type FSQuotationToOpenCLTranslator(translatorOptions: TranslatorOptions) =
| DerivedPatterns.Applications
(
Patterns.Var funcVar,
AtomicApplArgs (mutex, volatileVar)
AtomicApplArgs (_, volatileVar)
)
when funcVar.Name.StartsWith "atomic" ->

if kernelArgumentsNames |> List.contains volatileVar.Name then
atomicPointerArgQualifiers.Add(funcVar, Global)
elif localVarsNames |> List.contains volatileVar .Name then
elif localVarsNames |> List.contains volatileVar.Name then
atomicPointerArgQualifiers.Add(funcVar, Local)
else
failwith "Atomic pointer argument should be from local or global memory only"
Expand Down
2 changes: 2 additions & 0 deletions tests/Brahma.FSharp.Tests/AtomicTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ open ExpectoFsCheck
open FsCheck
open Brahma.FSharp.Tests

// TODO add tests in inc dec on supported types (generate spinlock)

let logger = Log.create "AtomicTests"

type NormalizedFloatArray =
Expand Down
42 changes: 37 additions & 5 deletions tests/Brahma.FSharp.Tests/CompositeTypesTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ open Expecto
open Brahma.FSharp.OpenCL
open FSharp.Quotations
open Brahma.FSharp.Tests
open FsCheck

// Incomplete pattern matching in record deconstruction
#nowarn "667"
Expand Down Expand Up @@ -143,6 +144,17 @@ let recordTestCases = testList "Record tests" [
if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>)
]

let genGenericStruct<'a, 'b> =
gen {
let! x = Arb.generate<'a>
let! y = Arb.generate<'b>

return GenericStruct(x, y)
}

type GenericStructGenerator =
static member GenericStruct() = Arb.fromGen genGenericStruct

let structTests = testList "Struct tests" [
testCase "Smoke test" <| fun _ ->
let command =
Expand Down Expand Up @@ -190,16 +202,36 @@ let structTests = testList "Struct tests" [

checkResult command [|StructOfIntInt64(1, 2L); StructOfIntInt64(3, 4L)|]
[|StructOfIntInt64(4, 2L); StructOfIntInt64(3, 4L)|]
]

// let nestedTypesTests = testList "" [
let inline command length =
<@
fun (gid: int) (buffer: ClArray<GenericStruct<'a, 'b>>) ->
if gid < length then
let tmp = buffer.[gid]
let x = tmp.X
let y = tmp.Y
let mutable innerStruct = GenericStruct(x, y)
innerStruct.X <- x
innerStruct.Y <- y
buffer.[gid] <- GenericStruct(innerStruct.X, innerStruct.Y)
@>

let config = { FsCheckConfig.defaultConfig with arbitrary = [typeof<GenericStructGenerator>] }

testPropertyWithConfig config (message "GenericStruct<int, bool>") <| fun (data: GenericStruct<int, bool>[]) ->
if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>)

// ]
testPropertyWithConfig config (message "GenericStruct<(int * int64), (bool * bool)>") <| fun (data: GenericStruct<(int * int64), (bool * bool)>[]) ->
if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>)

testPropertyWithConfig config (message "GenericStruct<RecordOfIntInt64, RecordOfBoolBool>") <| fun (data: GenericStruct<RecordOfIntInt64, RecordOfBoolBool>[]) ->
if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>)
]

let tests =
testList "Tests on composite types" [
// tupleTestCases
// recordTestCases
tupleTestCases
recordTestCases
structTests
]
|> testSequenced
2 changes: 2 additions & 0 deletions tests/Brahma.FSharp.Tests/Expected/Barrier.Full.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__kernel void brahmaKernel ()
{barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE) ;}
2 changes: 2 additions & 0 deletions tests/Brahma.FSharp.Tests/Expected/Barrier.Global.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__kernel void brahmaKernel ()
{barrier(CLK_GLOBAL_MEM_FENCE) ;}
2 changes: 2 additions & 0 deletions tests/Brahma.FSharp.Tests/Expected/Barrier.Local.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__kernel void brahmaKernel ()
{barrier(CLK_LOCAL_MEM_FENCE) ;}
Loading

0 comments on commit 7aea127

Please sign in to comment.