diff --git a/src/Brahma.FSharp.OpenCL.AST/Expressions.fs b/src/Brahma.FSharp.OpenCL.AST/Expressions.fs index 8dbf4297..459c526c 100644 --- a/src/Brahma.FSharp.OpenCL.AST/Expressions.fs +++ b/src/Brahma.FSharp.OpenCL.AST/Expressions.fs @@ -66,6 +66,7 @@ type BOp<'lang> = | Pow | BitAnd | BitOr + | BitXor | LeftShift | RightShift | And @@ -91,6 +92,7 @@ type UOp<'lang> = | Not | Incr | Decr + | BitNegation type Unop<'lang>(op: UOp<'lang>, expr: Expression<'lang>) = inherit Expression<'lang>() diff --git a/src/Brahma.FSharp.OpenCL.AST/Statements.fs b/src/Brahma.FSharp.OpenCL.AST/Statements.fs index 4d7a2d56..640da988 100644 --- a/src/Brahma.FSharp.OpenCL.AST/Statements.fs +++ b/src/Brahma.FSharp.OpenCL.AST/Statements.fs @@ -77,7 +77,7 @@ type ForIntegerLoop<'lang> ( var: VarDecl<'lang>, cond: Expression<'lang>, - countModifier: Expression<'lang>, + countModifier: Statement<'lang>, body: StatementBlock<'lang> ) = @@ -94,9 +94,15 @@ type WhileLoop<'lang>(cond: Expression<'lang>, whileBlock: StatementBlock<'lang> member this.Condition = cond member this.WhileBlock = whileBlock -type Barrier<'lang>() = +type MemFence = + | Local + | Global + | Both + +type Barrier<'lang>(memFence: MemFence) = inherit Statement<'lang>() override this.Children = [] + member this.MemFence = memFence type FieldSet<'lang>(host: Expression<'lang>, field: string, _val: Expression<'lang>) = inherit Statement<'lang>() diff --git a/src/Brahma.FSharp.OpenCL.Core/Brahma.FSharp.OpenCL.Core.fsproj b/src/Brahma.FSharp.OpenCL.Core/Brahma.FSharp.OpenCL.Core.fsproj index 139dbfd3..e1b090dd 100644 --- a/src/Brahma.FSharp.OpenCL.Core/Brahma.FSharp.OpenCL.Core.fsproj +++ b/src/Brahma.FSharp.OpenCL.Core/Brahma.FSharp.OpenCL.Core.fsproj @@ -20,14 +20,14 @@ - + - + @@ -39,4 +39,4 @@ - \ No newline at end of file + diff --git a/src/Brahma.FSharp.OpenCL.Core/ClContext.fs b/src/Brahma.FSharp.OpenCL.Core/ClContext.fs index 60cb6051..f40c4b00 100644 --- a/src/Brahma.FSharp.OpenCL.Core/ClContext.fs +++ 
b/src/Brahma.FSharp.OpenCL.Core/ClContext.fs @@ -35,6 +35,7 @@ type ClDeviceType = | GPU -> DeviceType.Gpu | Default -> DeviceType.Default +// TODO redesign module internal Device = open System.Text.RegularExpressions @@ -82,7 +83,7 @@ type ClContext private (context: Context, device: Device, translator: FSQuotatio ctx - let translator = FSQuotationToOpenCLTranslator() + let translator = FSQuotationToOpenCLTranslator(TranslatorOptions()) let queue = CommandQueueProvider.CreateQueue(context, device) ClContext(context, device, translator, queue) @@ -101,7 +102,7 @@ type ClContext private (context: Context, device: Device, translator: FSQuotatio member this.WithNewCommandQueue() = ClContext(this.Context, this.Device, this.Translator, CommandQueueProvider.CreateQueue(this.Context, this.Device)) - member this.CreateClKernel(srcLambda: Expr<'a -> 'b>) = + member this.CreateClProgram(srcLambda: Expr<'a -> 'b>) = ClProgram<_,_>(this, srcLambda) member this.CreateClBuffer diff --git a/src/Brahma.FSharp.OpenCL.Core/ClKernel.fs b/src/Brahma.FSharp.OpenCL.Core/ClKernel.fs deleted file mode 100644 index 9fa4ee39..00000000 --- a/src/Brahma.FSharp.OpenCL.Core/ClKernel.fs +++ /dev/null @@ -1,184 +0,0 @@ -namespace Brahma.FSharp.OpenCL - -open OpenCL.Net -open Microsoft.FSharp.Quotations -open FSharp.Quotations.Evaluator -open Brahma.FSharp.OpenCL.Translator -open System -open System.Runtime.InteropServices -open Brahma.FSharp.OpenCL.Shared -open Brahma.FSharp.OpenCL.Translator.QuotationTransformers - -type ClProgram<'TRange, 'a when 'TRange :> INDRangeDimension> ( - clContext: IContext, - srcLambda: Expr<'TRange ->'a> - ) = - - let (clCode, newLambda) = - let (ast, newLambda) = clContext.Translator.Translate(srcLambda) - let code = Printer.AST.print ast - code, newLambda - - let program = - let (program, error) = - let sources = [|clCode|] - Cl.CreateProgramWithSource(clContext.Context, uint32 sources.Length, sources, null) - - if error <> ErrorCode.Success then - failwithf 
"Program creation failed: %A" error - - let options = " -cl-fast-relaxed-math -cl-mad-enable -cl-unsafe-math-optimizations " - let error = Cl.BuildProgram(program, 1u, [| clContext.Device |], options, null, IntPtr.Zero) - - if error <> ErrorCode.Success then - let errorCode = ref ErrorCode.Success - let buildInfo = Cl.GetProgramBuildInfo(program, clContext.Device, ProgramBuildInfo.Log, errorCode) - failwithf "Program compilation failed: %A \n BUILD LOG:\n %A \n" error (buildInfo) - - program - - let mutexBuffers = ResizeArray>() - - member this.GetNewKernel (?kernelName) = new ClKernel<'TRange, 'a>(this, ?kernelName=kernelName) - member this.Program = program - member this.Code = clCode - member this.ArgumentsSetter range setupArgument = - let args = ref [||] - let getStarterFunction qExpr = - match qExpr with - | DerivedPatterns.Lambdas (lambdaArgs, _) -> - let flattenArgs = List.collect id lambdaArgs - - let firstMutexIdx = - flattenArgs - |> List.tryFindIndex (fun v -> v.Name.EndsWith "Mutex") - |> Option.defaultValue flattenArgs.Length - - let argsWithoutMutexes = flattenArgs.[0 .. firstMutexIdx - 1] - - let mutexLengths = - let atomicVars = - List.init (flattenArgs.Length - firstMutexIdx) <| fun i -> - let mutexVar = flattenArgs.[firstMutexIdx + i] - argsWithoutMutexes |> List.find (fun v -> mutexVar.Name.Contains v.Name) - - Expr.NewArray( - typeof, - - atomicVars - |> List.map (fun var -> - match var with - | var when var.Type.Name.ToLower().StartsWith ClArray_ -> - Expr.PropertyGet( - Expr.Var var, - typeof> - .GetGenericTypeDefinition() - .MakeGenericType(var.Type.GenericTypeArguments.[0]) - .GetProperty("Length") - ) - - | var when var.Type.Name.ToLower().StartsWith ClCell_ -> - Expr.Value 1 - - | _ -> - failwithf "Something went wrong with type of atomic global var. 
\ - Expected var of type '%s' or '%s', but given %s" ClArray_ ClCell_ var.Type.Name - ) - ) - - let regularArgs = - Expr.NewArray( - typeof, - argsWithoutMutexes |> List.map (fun v -> Expr.Coerce(Expr.Var v, typeof)) - ) - - Expr.Lambdas( - argsWithoutMutexes - |> List.map List.singleton, - - <@@ - let mutexArgs = - (%%mutexLengths : int[]) - |> List.ofArray - |> List.map (fun n -> - let mutexBuffer = new ClBuffer(clContext, Size n) - mutexBuffers.Add mutexBuffer - box mutexBuffer - ) - - let x = %%regularArgs |> List.ofArray - range := unbox<'TRange> x.Head - args := x.Tail @ mutexArgs |> Array.ofList - - !args - |> Array.iteri setupArgument - @@> - ) - - | _ -> failwithf "Invalid kernel expression. Must be lambda, but given\n%O" qExpr - - |> fun kernelPrepare -> - <@ %%kernelPrepare: 'TRange -> 'a @>.Compile() - - getStarterFunction newLambda - - member this.ReleaseBuffers() = - mutexBuffers - |> Seq.iter (Msg.CreateFreeMsg >> clContext.CommandQueue.Post) - - mutexBuffers.Clear() - - -and ClKernel<'TRange, 'a when 'TRange :> INDRangeDimension> - ( - clProgram: ClProgram<'TRange, 'a>, - //setArgsLambda: 'TRange ->'a, - ?kernelName:string - ) = - - let kernelName = defaultArg kernelName "brahmaKernel" - - let createKernel program = - let (clKernel, error) = Cl.CreateKernel(program, kernelName) - if error <> ErrorCode.Success then - failwithf "OpenCL kernel creation problem. Error: %A" error - clKernel - - let kernel = - let clKernel = createKernel clProgram.Program - clKernel - - let toIMem a = - // TODO extend types for private args (now only int supported) - match box a with - | :? IClMem as buf -> buf.Size, buf.Data - | :? 
int as i -> IntPtr(Marshal.SizeOf i), box i - | other -> failwithf "Unexpected argument: %A" other - - let setupArgument index arg = - let (argSize, argVal) = toIMem arg - // NOTE SetKernelArg could take intptr - // TODO try allocate unmanaged mem by hand - let error = Cl.SetKernelArg(kernel, uint32 index, argSize, argVal) - if error <> ErrorCode.Success then - raise (CLException error) - - let range = ref Unchecked.defaultof<'TRange> - - interface IKernel<'TRange, 'a> with - member this.ArgumentsSetter = - clProgram.ArgumentsSetter range setupArgument - - member this.Kernel = kernel - member this.Range = range.Value :> INDRangeDimension - member this.Code = clProgram.Code - member this.ReleaseBuffers() = clProgram.ReleaseBuffers() - - member this.ArgumentsSetter = (this :> IKernel<_,_>).ArgumentsSetter - member this.Kernel = (this :> IKernel<_,_>).Kernel - member this.Range = (this :> IKernel<_,_>).Range - member this.Code = (this :> IKernel<_,_>).Code - - // TODO rename ?? ReleaseInternalBuffers - /// Освобождает только временные промежуточные утилитарные буферы (например, буфер для мьютексов) - member this.ReleaseBuffers() = (this :> IKernel<_,_>).ReleaseBuffers() diff --git a/src/Brahma.FSharp.OpenCL.Core/ClProgram.fs b/src/Brahma.FSharp.OpenCL.Core/ClProgram.fs new file mode 100644 index 00000000..68d797cc --- /dev/null +++ b/src/Brahma.FSharp.OpenCL.Core/ClProgram.fs @@ -0,0 +1,158 @@ +namespace Brahma.FSharp.OpenCL + +open OpenCL.Net +open Microsoft.FSharp.Quotations +open FSharp.Quotations.Evaluator +open Brahma.FSharp.OpenCL.Translator +open Brahma.FSharp.OpenCL.Printer +open System +open System.Runtime.InteropServices +open Brahma.FSharp.OpenCL.Shared +open Brahma.FSharp.OpenCL.Translator.QuotationTransformers + +type ClProgram<'TRange, 'a when 'TRange :> INDRange> + ( + clContext: IContext, + srcLambda: Expr<'TRange ->'a> + ) = + + let (clCode, newLambda) = + let (ast, newLambda) = clContext.Translator.Translate(srcLambda) + let code = AST.print ast + 
code, newLambda + + let program = + let (program, error) = + let sources = [|clCode|] + Cl.CreateProgramWithSource(clContext.Context, uint32 sources.Length, sources, null) + + if error <> ErrorCode.Success then + failwithf "Program creation failed: %A" error + + let options = " -cl-fast-relaxed-math -cl-mad-enable -cl-unsafe-math-optimizations " + let error = Cl.BuildProgram(program, 1u, [| clContext.Device |], options, null, IntPtr.Zero) + + if error <> ErrorCode.Success then + let errorCode = ref ErrorCode.Success + let buildInfo = Cl.GetProgramBuildInfo(program, clContext.Device, ProgramBuildInfo.Log, errorCode) + failwithf "Program compilation failed: %A \n BUILD LOG:\n %A \n" error (buildInfo) + + program + + let createKernel program kernelName = + let (clKernel, error) = Cl.CreateKernel(program, kernelName) + if error <> ErrorCode.Success then + failwithf "OpenCL kernel creation problem. Error: %A" error + clKernel + + member this.Program = program + + member this.Code = clCode + + member this.GetKernel(?kernelName) = + let kernelName = defaultArg kernelName "brahmaKernel" + let kernel = createKernel program kernelName + + let toIMem arg = + match box arg with + | :? IClMem as buf -> buf.Size, buf.Data + | :? int as i -> IntPtr(Marshal.SizeOf i), box i + | other -> failwithf "Unexpected argument: %A" other + + let setupArgument index (arg: obj) = + let (argSize, argVal) = toIMem arg + let error = Cl.SetKernelArg(kernel, uint32 index, argSize, argVal) + if error <> ErrorCode.Success then + raise (CLException error) + + let args = ref [||] + let range = ref Unchecked.defaultof<'TRange> + let mutexBuffers = ResizeArray>() + + let argumentsSetterFunc = + match newLambda with + | DerivedPatterns.Lambdas (lambdaArgs, _) -> + let flattenArgs = List.collect id lambdaArgs + + let firstMutexIdx = + flattenArgs + |> List.tryFindIndex (fun v -> v.Name.EndsWith "Mutex") + |> Option.defaultValue flattenArgs.Length + + let argsWithoutMutexes = flattenArgs.[0 .. 
firstMutexIdx - 1] + + let mutexLengths = + let atomicVars = + List.init (flattenArgs.Length - firstMutexIdx) <| fun i -> + let mutexVar = flattenArgs.[firstMutexIdx + i] + argsWithoutMutexes |> List.find (fun v -> mutexVar.Name.Contains v.Name) + + Expr.NewArray( + typeof, + + atomicVars + |> List.map (fun var -> + match var with + | var when var.Type.Name.ToLower().StartsWith ClArray_ -> + Expr.PropertyGet( + Expr.Var var, + typeof> + .GetGenericTypeDefinition() + .MakeGenericType(var.Type.GenericTypeArguments.[0]) + .GetProperty("Length") + ) + + | var when var.Type.Name.ToLower().StartsWith ClCell_ -> + Expr.Value 1 + + | _ -> + failwithf "Something went wrong with type of atomic global var. \ + Expected var of type '%s' or '%s', but given %s" ClArray_ ClCell_ var.Type.Name + ) + ) + + let regularArgs = + Expr.NewArray( + typeof, + argsWithoutMutexes |> List.map (fun v -> Expr.Coerce(Expr.Var v, typeof)) + ) + + Expr.Lambdas( + argsWithoutMutexes + |> List.map List.singleton, + + <@@ + let mutexArgs = + (%%mutexLengths : int[]) + |> List.ofArray + |> List.map (fun n -> + let mutexBuffer = new ClBuffer(clContext, Size n) + mutexBuffers.Add mutexBuffer + box mutexBuffer + ) + + let x = %%regularArgs |> List.ofArray + range := unbox<'TRange> x.Head + args := x.Tail @ mutexArgs |> Array.ofList + + !args + |> Array.iteri setupArgument + @@> + ) + + | _ -> failwithf "Invalid kernel expression. 
Must be lambda, but given\n%O" newLambda + + |> fun kernelPrepare -> + <@ %%kernelPrepare : 'TRange -> 'a @>.Compile() + + { new IKernel<'TRange, 'a> with + member _.Kernel = kernel + member _.NDRange = range.Value :> INDRange + member _.KernelFunc = argumentsSetterFunc + member _.ReleaseInternalBuffers() = + mutexBuffers + |> Seq.iter (Msg.CreateFreeMsg >> clContext.CommandQueue.Post) + + mutexBuffers.Clear() + } + diff --git a/src/Brahma.FSharp.OpenCL.Core/ClTask.fs b/src/Brahma.FSharp.OpenCL.Core/ClTask.fs index 31b4b02d..76eab3b7 100644 --- a/src/Brahma.FSharp.OpenCL.Core/ClTask.fs +++ b/src/Brahma.FSharp.OpenCL.Core/ClTask.fs @@ -83,7 +83,14 @@ module ClTask = context.CommandQueue.PostAndReply <| MsgNotifyMe res - // TODO maybe switсh to manual threads + // TODO implement + // let startSync (ClTask f) = + // let context = Device.getFirstAppropriateDevice + // let res = f context + // context.CommandQueue.PostAndReply <| MsgNotifyMe + // res + + // NOTE maybe switсh to manual threads // TODO check if it is really parallel let inParallel (tasks: seq>) = opencl { let! ctx = ask @@ -115,16 +122,16 @@ module ClTaskOpened = opencl { let! 
ctx = ClTask.ask - let kernel = (ctx.CreateClKernel command).GetNewKernel() + let kernel = ctx.CreateClProgram(command).GetKernel() - ctx.CommandQueue.Post <| MsgSetArguments(fun () -> binder kernel.ArgumentsSetter) + ctx.CommandQueue.Post <| MsgSetArguments(fun () -> binder kernel.KernelFunc) ctx.CommandQueue.Post <| Msg.CreateRunMsg<_, _>(kernel) - kernel.ReleaseBuffers() + kernel.ReleaseInternalBuffers() } - let runKernel (kernel: ClKernel<'range, 'a>) (processor: MailboxProcessor) (binder: ('range -> 'a) -> unit) : ClTask = + let runKernel (kernel: IKernel<'range, 'a>) (processor: MailboxProcessor) (binder: ('range -> 'a) -> unit) : ClTask = opencl { - processor.Post <| MsgSetArguments(fun () -> binder kernel.ArgumentsSetter) + processor.Post <| MsgSetArguments(fun () -> binder kernel.KernelFunc) processor.Post <| Msg.CreateRunMsg<_, _>(kernel) - kernel.ReleaseBuffers() + kernel.ReleaseInternalBuffers() } diff --git a/src/Brahma.FSharp.OpenCL.Core/CommandQueueProvider.fs b/src/Brahma.FSharp.OpenCL.Core/CommandQueueProvider.fs index 3b6189a0..ad5fcbe8 100644 --- a/src/Brahma.FSharp.OpenCL.Core/CommandQueueProvider.fs +++ b/src/Brahma.FSharp.OpenCL.Core/CommandQueueProvider.fs @@ -129,7 +129,7 @@ type CommandQueueProvider = static member private HandleRun(queue, run: IRunCrate) = { new IRunCrateEvaluator with member this.Eval crate = - let range = crate.Kernel.Range + let range = crate.Kernel.NDRange let workDim = uint32 range.Dimensions let eventID = ref Unchecked.defaultof let error = Cl.EnqueueNDRangeKernel(queue, crate.Kernel.Kernel, workDim, null, diff --git a/src/Brahma.FSharp.OpenCL.Core/CustomMarshaler.fs b/src/Brahma.FSharp.OpenCL.Core/CustomMarshaler.fs index c6c14a6a..bd695caf 100644 --- a/src/Brahma.FSharp.OpenCL.Core/CustomMarshaler.fs +++ b/src/Brahma.FSharp.OpenCL.Core/CustomMarshaler.fs @@ -5,6 +5,7 @@ open System.Runtime.InteropServices open Brahma.FSharp.OpenCL.Translator open FSharp.Reflection open System.Runtime.CompilerServices +open 
System.Threading.Tasks type StructurePacking = | StructureElement of {| Size: int; Aligment: int |} * StructurePacking list @@ -15,6 +16,7 @@ module private Utils = |> Seq.tryFind (fun attr -> attr.GetType() = typeof<'attr>) |> Option.isSome +// TODO make read write parallel type CustomMarshaler<'a>() = let (|TupleType|RecordType|UnionType|UserDefinedStuctureType|PrimitiveType|) (type': Type) = match type' with @@ -81,7 +83,26 @@ type CustomMarshaler<'a>() = StructureElement({| Size = size; Aligment = aligment |}, elems) | UnionType -> failwithf "Union not supported" - | UserDefinedStuctureType -> failwithf "Custom structures not supported" + + | UserDefinedStuctureType -> + let elems = + type'.GetFields() + |> Array.map (fun fi -> fi.FieldType) + |> Array.map go + |> Array.toList + + let aligment = + elems + |> List.map (fun (StructureElement(pack, _)) -> pack.Aligment) + |> List.max + + let size = + elems + |> List.map (fun (StructureElement(pack, _)) -> pack) + |> List.fold (fun state x -> roundUp x.Aligment state + x.Size) 0 + |> roundUp aligment + + StructureElement({| Size = size; Aligment = aligment |}, elems) | PrimitiveType -> let size = Marshal.SizeOf (if type' = typeof then typeof else type') @@ -129,7 +150,7 @@ type CustomMarshaler<'a>() = size, mem member this.WriteToUnmanaged(array: 'a[], ptr: IntPtr) = - for j = 0 to array.Length - 1 do + Array.Parallel.iteri (fun j item -> let start = IntPtr.Add(ptr, j * this.ElementTypeSize) let mutable i = 0 let rec go (structure: obj) = @@ -140,10 +161,15 @@ type CustomMarshaler<'a>() = [ 0 .. 
tupleSize - 1 ] |> List.iter (fun i -> go tuple.[i]) | Record -> - FSharpValue.GetRecordFields structure |> Array.iter go + FSharpValue.GetRecordFields structure + |> Array.iter go | Union -> failwithf "Union not supported" - | UserDefinedStucture -> failwithf "Custom structures not supported" + + | UserDefinedStucture -> + structure.GetType().GetFields() + |> Array.map (fun fi -> fi.GetValue(structure)) + |> Array.iter go | Primitive -> let offset = this.ElementTypeOffsets.[i] @@ -155,7 +181,8 @@ type CustomMarshaler<'a>() = Marshal.StructureToPtr(structure, IntPtr.Add(start, offset), false) i <- i + 1 - go array.[j] + go item + ) array array.Length * this.ElementTypeSize @@ -165,7 +192,7 @@ type CustomMarshaler<'a>() = array member this.ReadFromUnmanaged(ptr: IntPtr, array: 'a[]) = - for j = 0 to array.Length - 1 do + Array.Parallel.iteri (fun j _ -> let start = IntPtr.Add(ptr, j * this.ElementTypeSize) let mutable i = 0 let rec go (type': Type) = @@ -182,7 +209,14 @@ type CustomMarshaler<'a>() = |> fun x -> FSharpValue.MakeRecord(type', x) | UnionType -> failwithf "Union not supported" - | UserDefinedStuctureType -> failwithf "Custom structures not supported" + + | UserDefinedStuctureType -> + let inst = Activator.CreateInstance(type') + type'.GetFields() + |> Array.map (fun fi -> fi, go fi.FieldType) + |> Array.iter (fun (fi, value) -> fi.SetValue(inst, value)) + + inst | PrimitiveType -> let offset = this.ElementTypeOffsets.[i] @@ -196,6 +230,7 @@ type CustomMarshaler<'a>() = structure array.[j] <- unbox<'a> <| go typeof<'a> + ) array override this.ToString() = sprintf "%O\n%A" elementPacking offsets diff --git a/src/Brahma.FSharp.OpenCL.Core/IKernel.fs b/src/Brahma.FSharp.OpenCL.Core/IKernel.fs index d337d631..b7c6e2ae 100644 --- a/src/Brahma.FSharp.OpenCL.Core/IKernel.fs +++ b/src/Brahma.FSharp.OpenCL.Core/IKernel.fs @@ -2,9 +2,9 @@ namespace Brahma.FSharp.OpenCL open OpenCL.Net -type IKernel<'TRange, 'a when 'TRange :> INDRangeDimension> = - abstract 
ArgumentsSetter : ('TRange -> 'a) +type IKernel<'TRange, 'a when 'TRange :> INDRange> = abstract Kernel : Kernel - abstract Range : INDRangeDimension - abstract Code : string - abstract ReleaseBuffers : unit -> unit + abstract NDRange : INDRange + // not sure about naming + abstract KernelFunc : ('TRange -> 'a) + abstract ReleaseInternalBuffers : unit -> unit diff --git a/src/Brahma.FSharp.OpenCL.Core/Messages.fs b/src/Brahma.FSharp.OpenCL.Core/Messages.fs index 3ed2e6be..15a9c624 100644 --- a/src/Brahma.FSharp.OpenCL.Core/Messages.fs +++ b/src/Brahma.FSharp.OpenCL.Core/Messages.fs @@ -14,7 +14,7 @@ type ToGPU<'a when 'a: struct>(src: 'a[], dst: IBuffer<'a>) = member this.Destination = dst member this.Source = src -type Run<'TRange, 'a when 'TRange :> INDRangeDimension>(kernel: IKernel<'TRange, 'a>) = +type Run<'TRange, 'a when 'TRange :> INDRange>(kernel: IKernel<'TRange, 'a>) = member this.Kernel = kernel type IRunCrate = @@ -76,7 +76,7 @@ type Msg = } |> MsgFree - static member CreateRunMsg<'TRange, 'a when 'TRange :> INDRangeDimension>(kernel) = + static member CreateRunMsg<'TRange, 'a when 'TRange :> INDRange>(kernel) = { new IRunCrate with member this.Apply evaluator = evaluator.Eval <| Run<'TRange, 'a>(kernel) } diff --git a/src/Brahma.FSharp.OpenCL.Core/NDRangeDimensions.fs b/src/Brahma.FSharp.OpenCL.Core/NDRange.fs similarity index 95% rename from src/Brahma.FSharp.OpenCL.Core/NDRangeDimensions.fs rename to src/Brahma.FSharp.OpenCL.Core/NDRange.fs index a8562253..bbe48d86 100644 --- a/src/Brahma.FSharp.OpenCL.Core/NDRangeDimensions.fs +++ b/src/Brahma.FSharp.OpenCL.Core/NDRange.fs @@ -2,7 +2,7 @@ namespace Brahma.FSharp.OpenCL open System -type INDRangeDimension = +type INDRange = abstract member GlobalWorkSize: IntPtr[] with get abstract member LocalWorkSize: IntPtr[] with get abstract member Dimensions: int @@ -16,7 +16,7 @@ type Range1D(globalWorkSize: int, localWorkSize: int) = member this.GlobalWorkSize = globalWorkSize member this.LocalWorkSize = 
localWorkSize - interface INDRangeDimension with + interface INDRange with member this.GlobalWorkSize with get () = [| IntPtr globalWorkSize |] member this.LocalWorkSize with get () = [| IntPtr localWorkSize |] member this.Dimensions = 1 @@ -36,7 +36,7 @@ type Range2D(globalWorkSizeX: int, globalWorkSizeY: int, localWorkSizeX: int, lo member this.GlobalWorkSize = (globalWorkSizeX, globalWorkSizeY) member this.LocalWorkSize = (localWorkSizeX, localWorkSizeY) - interface INDRangeDimension with + interface INDRange with member this.GlobalWorkSize with get () = [| IntPtr globalWorkSizeX; IntPtr globalWorkSizeY |] member this.LocalWorkSize with get () = [| IntPtr localWorkSizeX; IntPtr localWorkSizeY |] member this.Dimensions = 2 @@ -54,7 +54,7 @@ type Range3D(globalWorkSizeX: int, globalWorkSizeY: int, globalWorkSizeZ: int, l member this.GlobalWorkSize = (globalWorkSizeX, globalWorkSizeY, globalWorkSizeZ) member this.LocalWorkSize = (localWorkSizeX, localWorkSizeY, localWorkSizeZ) - interface INDRangeDimension with + interface INDRange with member this.GlobalWorkSize with get () = [| IntPtr globalWorkSizeX; IntPtr globalWorkSizeY; IntPtr globalWorkSizeZ |] member this.LocalWorkSize with get () = [| IntPtr localWorkSizeX; IntPtr localWorkSizeY; IntPtr globalWorkSizeZ |] member this.Dimensions = 3 diff --git a/src/Brahma.FSharp.OpenCL.Printer/Expressions.fs b/src/Brahma.FSharp.OpenCL.Printer/Expressions.fs index 0e1f2a4b..5b983f16 100644 --- a/src/Brahma.FSharp.OpenCL.Printer/Expressions.fs +++ b/src/Brahma.FSharp.OpenCL.Printer/Expressions.fs @@ -59,10 +59,11 @@ module Expressions = | Pow -> "+" | BitAnd -> "&" | BitOr -> "|" - | And -> "&&" - | Or -> "||" + | BitXor -> "^" | LeftShift -> "<<" | RightShift -> ">>" + | And -> "&&" + | Or -> "||" | Less -> "<" | LessEQ -> "<=" | Great -> ">" @@ -109,6 +110,7 @@ module Expressions = | UOp.Not -> wordL "!" 
++ print uo.Expr |> bracketL | UOp.Incr -> print uo.Expr ++ wordL "++" | UOp.Decr -> print uo.Expr ++ wordL "--" + | UOp.BitNegation -> wordL "~" ++ print uo.Expr |> bracketL and private printCast (c: Cast<'lang>) = let t = Types.print c.Type @@ -135,8 +137,11 @@ module Expressions = and printNewStruct (newStruct: NewStruct<_>) = let args = List.map print newStruct.ConstructorArgs |> commaListL - let t = Types.print newStruct.Struct - [ t |> bracketL; wordL "{"; args; wordL "}" ] |> spaceListL + match newStruct.Struct with + | :? StructInplaceType<_> -> [ wordL "{"; args; wordL "}" ] |> spaceListL + | _ -> + let t = Types.print newStruct.Struct + [ t |> bracketL; wordL "{"; args; wordL "}" ] |> spaceListL and printNewUnion (newUnion: NewUnion<_>) = let arg = print newUnion.ConstructorArg diff --git a/src/Brahma.FSharp.OpenCL.Printer/Printer.fs b/src/Brahma.FSharp.OpenCL.Printer/Printer.fs index 154ea18a..af8a0f3d 100644 --- a/src/Brahma.FSharp.OpenCL.Printer/Printer.fs +++ b/src/Brahma.FSharp.OpenCL.Printer/Printer.fs @@ -16,8 +16,7 @@ namespace Brahma.FSharp.OpenCL.Printer open Brahma.FSharp.OpenCL.AST -open Microsoft.FSharp.Text -open Microsoft.FSharp.Text.StructuredFormat.LayoutOps +open Microsoft.FSharp.Text.StructuredFormat open Brahma.FSharp.OpenCL.Printer module AST = @@ -28,10 +27,12 @@ module AST = match d with | :? FunDecl<'lang> as fd -> FunDecl.print fd | :? CLPragma<'lang> as clp -> Pragmas.print clp - | :? StructDecl<'lang> as s -> TypeDecl.PrintStructDeclaration s + | :? StructDecl<'lang> as s -> TypeDecl.printStructDeclaration s | :? VarDecl<'lang> as s -> Statements.print false s | _ -> failwithf "Printer. 
Unsupported toplevel declaration: %A" d ) - |> aboveListL - |> StructuredFormat.Display.layout_to_string { StructuredFormat.FormatOptions.Default with PrintWidth = 100 } + // |> LayoutOps.sepListL (LayoutOps.wordL "\r\n") + // |> Display.layout_to_string FormatOptions.Default + |> LayoutOps.aboveListL + |> Display.layout_to_string { FormatOptions.Default with PrintWidth = 100 } diff --git a/src/Brahma.FSharp.OpenCL.Printer/Statements.fs b/src/Brahma.FSharp.OpenCL.Printer/Statements.fs index b7b0041b..3697d54f 100644 --- a/src/Brahma.FSharp.OpenCL.Printer/Statements.fs +++ b/src/Brahma.FSharp.OpenCL.Printer/Statements.fs @@ -82,7 +82,7 @@ module Statements = and private printForInteger (for': ForIntegerLoop<_>) = let cond = Expressions.print for'.Condition let i = print true for'.Var - let cModif = Expressions.print for'.CountModifier + let cModif = print true for'.CountModifier let body = print true for'.Body let header = [ i; cond; cModif ] |> sepListL (wordL ";") |> bracketL @@ -111,7 +111,11 @@ module Statements = wordL fc.Name ++ args - and printBarrier (b: Barrier<_>) = wordL "barrier(CLK_LOCAL_MEM_FENCE)" + and printBarrier (b: Barrier<_>) = + match b.MemFence with + | MemFence.Local -> wordL "barrier(CLK_LOCAL_MEM_FENCE)" + | MemFence.Global -> wordL "barrier(CLK_GLOBAL_MEM_FENCE)" + | Both -> wordL "barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE)" and printReturn (r: Return<_>) = wordL "return" ++ Expressions.print r.Expression @@ -143,6 +147,7 @@ module Statements = | :? FieldSet<'lang> as fs -> printFieldSet fs | :? Return<'lang> as r -> printReturn r //| :? Variable<'lang> as v -> printVar v + | :? Expression<'lang> as e -> Expressions.print e | _ -> failwithf "Printer. 
Unsupported statement: %O" stmt if isToplevel then diff --git a/src/Brahma.FSharp.OpenCL.Printer/TypeDecl.fs b/src/Brahma.FSharp.OpenCL.Printer/TypeDecl.fs index 003d5833..a0dd07a0 100644 --- a/src/Brahma.FSharp.OpenCL.Printer/TypeDecl.fs +++ b/src/Brahma.FSharp.OpenCL.Printer/TypeDecl.fs @@ -20,7 +20,7 @@ open Brahma.FSharp.OpenCL.Printer open Microsoft.FSharp.Text.StructuredFormat.LayoutOps module TypeDecl = - let PrintStructDeclaration (decl: StructDecl<_>) = + let printStructDeclaration (decl: StructDecl<_>) = let header = [ wordL "typedef" diff --git a/src/Brahma.FSharp.OpenCL.Shared/KernelLangExtensions.fs b/src/Brahma.FSharp.OpenCL.Shared/KernelLangExtensions.fs index 1f23d8e2..dc8c35ab 100644 --- a/src/Brahma.FSharp.OpenCL.Shared/KernelLangExtensions.fs +++ b/src/Brahma.FSharp.OpenCL.Shared/KernelLangExtensions.fs @@ -1,16 +1,24 @@ namespace Brahma.FSharp.OpenCL [] -type KernelLangExtentions = +type KernelLangExtensions = static member FailIfOutsideKernel() = failwith "Seems that you try to use openCL kernel function as regular F# function!" [] -module KernelLangExtentions = +module KernelLangExtensions = let inline internal failIfOutsideKernel () = failwith "Seems that you try to use openCL kernel function as regular F# function!" 
- let barrier () = + let barrierLocal () = + failIfOutsideKernel () + ignore null + + let barrierGlobal () = + failIfOutsideKernel () + ignore null + + let barrierFull () = failIfOutsideKernel () ignore null diff --git a/src/Brahma.FSharp.OpenCL.Translator/Body.fs b/src/Brahma.FSharp.OpenCL.Translator/Body.fs index 495566d1..d59518d2 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Body.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Body.fs @@ -22,15 +22,94 @@ open Microsoft.FSharp.Collections open FSharpx.Collections open Brahma.FSharp.OpenCL.Translator.QuotationTransformers open Brahma.FSharp.OpenCL +open FSharp.Quotations.Evaluator // Translations restricts the generic parameter of the AST nodes to the type Lang #nowarn "64" +[] +module private BodyPatterns = + let (|VarName|_|) (str: string) (var': Var) = + match var'.Name with + | tName when tName = str -> Some VarName + | _ -> None + + let (|ForLoopWithStep|_|) = function + | Patterns.Let + ( + VarName "inputSequence", + DerivedPatterns.SpecificCall <@ (.. ..) @> ( + _, + _, + [start; step; finish] + ), + Patterns.Let ( + VarName "enumerator", + _, + Patterns.TryFinally ( + Patterns.WhileLoop ( + _, + Patterns.Let ( + loopVar, + _, + loopBody + ) + ), + _ + ) + ) + ) -> Some (loopVar, (start, step, finish), loopBody) + | _ -> None + + let (|ForLoop|_|) = function + | Patterns.Let + ( + VarName "inputSequence", + DerivedPatterns.SpecificCall <@ (..) @> ( + _, + _, + [start; finish] + ), + Patterns.Let ( + VarName "enumerator", + _, + Patterns.TryFinally ( + Patterns.WhileLoop ( + _, + Patterns.Let ( + loopVar, + _, + loopBody + ) + ), + _ + ) + ) + ) -> Some (loopVar, (start, finish), loopBody) + | _ -> None + module rec Body = // new var scope let private clearContext (targetContext: TranslationContext<'a, 'b>) = { targetContext with VarDecls = ResizeArray() } + let toStb (s: Node<_>) = translation { + match s with + | :? 
StatementBlock<_> as s -> + return s + | x -> return StatementBlock <| ResizeArray [x :?> Statement<_>] + } + + let private itemHelper exprs hostVar = translation { + let! idx = translation { + match exprs with + | hd :: _ -> return! translateAsExpr hd + | [] -> return raise <| InvalidKernelException("Array index missed!") + } + + return idx, hostVar + } + let private translateBinding (var: Var) newName (expr: Expr) = translation { let! body = translateCond (*TranslateAsExpr*) expr let! varType = translation { @@ -75,52 +154,55 @@ module rec Body = | "op_modulus" -> return Binop(Remainder, args.[0], args.[1]) :> Statement<_> | "op_bitwiseand" -> return Binop(BitAnd, args.[0], args.[1]) :> Statement<_> | "op_bitwiseor" -> return Binop(BitOr, args.[0], args.[1]) :> Statement<_> + | "op_exclusiveor" -> return Binop(BitXor, args.[0], args.[1]) :> Statement<_> + | "op_logicalnot" -> return Unop(UOp.BitNegation, args.[0]) :> Statement<_> | "op_leftshift" -> return Binop(LeftShift, args.[0], args.[1]) :> Statement<_> | "op_rightshift" -> return Binop(RightShift, args.[0], args.[1]) :> Statement<_> | "op_booleanand" -> - let! flag = State.gets (fun context -> context.TranslatorOptions |> List.contains UseNativeBooleanType) + let! flag = State.gets (fun context -> context.TranslatorOptions.UseNativeBooleanType) if flag then return Binop(And, args.[0], args.[1]) :> Statement<_> else return Binop(BitAnd, args.[0], args.[1]) :> Statement<_> | "op_booleanor" -> - let! flag = State.gets (fun context -> context.TranslatorOptions |> List.contains UseNativeBooleanType) + let! flag = State.gets (fun context -> context.TranslatorOptions.UseNativeBooleanType) if flag then return Binop(Or, args.[0], args.[1]) :> Statement<_> else return Binop(BitOr, args.[0], args.[1]) :> Statement<_> + | "not" -> return Unop(UOp.Not, args.[0]) :> Statement<_> | "atomicadd" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! 
State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_add", [args.[0]; args.[1]]) :> Statement<_> | "atomicsub" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_sub", [args.[0]; args.[1]]) :> Statement<_> | "atomicxchg" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_xchg", [args.[0]; args.[1]]) :> Statement<_> | "atomicmax" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_max", [args.[0]; args.[1]]) :> Statement<_> | "atomicmin" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_min", [args.[0]; args.[1]]) :> Statement<_> | "atomicinc" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_inc", [args.[0]]) :> Statement<_> | "atomicdec" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_dec", [args.[0]]) :> Statement<_> | "atomiccmpxchg" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_cmpxchg", [args.[0]; args.[1]; args.[2]]) :> Statement<_> | "atomicand" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! 
State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_and", [args.[0]; args.[1]]) :> Statement<_> | "atomicor" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_or", [args.[0]; args.[1]]) :> Statement<_> | "atomicxor" -> - do! State.modify (fun context -> context.Flags.enableAtomic <- true; context) + do! State.modify (fun context -> context.Flags.Add EnableAtomic |> ignore; context) return FunCall("atom_xor", [args.[0]; args.[1]]) :> Statement<_> | "todouble" -> return Cast(args.[0], PrimitiveType Float) :> Statement<_> | "toint" -> return Cast(args.[0], PrimitiveType Int) :> Statement<_> @@ -131,6 +213,8 @@ module rec Body = | "touint16" -> return Cast(args.[0], PrimitiveType UShort) :> Statement<_> | "toint64" -> return Cast(args.[0], PrimitiveType Long) :> Statement<_> | "touint64" -> return Cast(args.[0], PrimitiveType ULong) :> Statement<_> + | "min" + | "max" | "acos" | "asin" | "atan" @@ -155,7 +239,6 @@ module rec Body = return raise <| InvalidKernelException( sprintf "Seems, that you use math function with name %s not from System.Math or Microsoft.FSharp.Core.Operators" fName ) - | "abs" as fName -> if mInfo.DeclaringType.AssemblyQualifiedName.StartsWith("Microsoft.FSharp.Core.Operators") then return FunCall("fabs", args) :> Statement<_> @@ -177,11 +260,11 @@ module rec Body = | "setarray" -> return Assignment(Property(PropertyType.Item(Item(args.[0], args.[1]))), args.[2]) :> Statement<_> | "getarray" -> return Item(args.[0], args.[1]) :> Statement<_> - | "not" -> return Unop(UOp.Not, args.[0]) :> Statement<_> - | "_byte" -> return args.[0] :> Statement<_> - | "barrier" -> return Barrier() :> Statement<_> + | "barrierlocal" -> return Barrier(MemFence.Local) :> Statement<_> + | "barrierglobal" -> return Barrier(MemFence.Global) :> Statement<_> + | "barrierfull" -> return 
Barrier(MemFence.Both) :> Statement<_> | "local" -> return raise <| InvalidKernelException("Calling the local function is allowed only at the top level of the let binding") - | "arrayLocal" -> return raise <| InvalidKernelException("Calling the localArray function is allowed only at the top level of the let binding") + | "arraylocal" -> return raise <| InvalidKernelException("Calling the localArray function is allowed only at the top level of the let binding") | "zerocreate" -> let length = match args.[0] with @@ -190,22 +273,9 @@ module rec Body = return ZeroArray length :> Statement<_> | "fst" -> return FieldGet(args.[0], "_1") :> Statement<_> | "snd" -> return FieldGet(args.[0], "_2") :> Statement<_> - | "first" -> return FieldGet(args.[0], "_1") :> Statement<_> - | "second" -> return FieldGet(args.[0], "_2") :> Statement<_> - | "third" -> return FieldGet(args.[0], "_3") :> Statement<_> | other -> return raise <| InvalidKernelException(sprintf "Unsupported call: %s" other) } - let private itemHelper exprs hostVar = translation { - let! idx = translation { - match exprs with - | hd :: _ -> return! translateAsExpr hd - | [] -> return raise <| InvalidKernelException("Array index missed!") - } - - return idx, hostVar - } - let private translateSpecificPropGet expr propName exprs = translation { // TODO: Refactoring: Safe pattern matching by expr type. let! hostVar = translateAsExpr expr @@ -245,9 +315,9 @@ module rec Body = match exprOpt with | Some expr -> - match! State.gets (fun context -> context.UserDefinedTypes.Contains expr.Type) with + match! State.gets (fun context -> context.CStructDecls.Keys |> Seq.contains expr.Type) with | true -> - match! State.gets (fun context -> context.StructDecls.ContainsKey expr.Type) with + match! State.gets (fun context -> not <| context.CStructDecls.[expr.Type] :? DiscriminatedUnionType<_>) with | true -> return! translateStructFieldGet expr propInfo.Name | false -> return! 
translateUnionFieldGet expr propInfo | false -> return! translateSpecificPropGet expr propName exprs @@ -295,14 +365,9 @@ module rec Body = return translated :?> Expression<_> } - let getVar (clVarName: string) = translation { - return Variable clVarName - } - let translateVar (var: Var) = translation { - //getVar var.Name targetContext match! State.gets (fun context -> context.Namer.GetCLVarName var.Name) with - | Some varName -> return! getVar varName + | Some varName -> return Variable varName | None -> return raise <| InvalidKernelException( sprintf @@ -355,7 +420,7 @@ module rec Body = let! l = translateCond if' let! r = translateCond then' let! e = translateCond else' - let! isBoolAsBit = State.gets (fun context -> context.TranslatorOptions |> List.contains BoolAsBit) + let! isBoolAsBit = State.gets (fun context -> context.TranslatorOptions.BoolAsBit) let o1 = match r with | :? Const as c when c.Val = "1" -> l @@ -369,13 +434,6 @@ module rec Body = | _ -> return! translateAsExpr cond } - let toStb (s: Node<_>) = translation { - match s with - | :? StatementBlock<_> as s -> - return s - | x -> return StatementBlock <| ResizeArray [x :?> Statement<_>] - } - let translateIf (cond: Expr) (thenBranch: Expr) (elseBranch: Expr) = translation { let! if' = translateCond cond let! then' = translate thenBranch >>= toStb |> State.using clearContext @@ -392,18 +450,37 @@ module rec Body = return IfThenElse(if', then', else') } - // TODO refac - let translateForIntegerRangeLoop (i: Var) (from': Expr) (to': Expr) (loopBody: Expr) = translation { - let! iName = State.gets (fun context -> context.Namer.LetStart i.Name) - let! v = getVar iName - let! var = translateBinding i iName from' + // NOTE reversed loops not supported + let translateForLoop (loopVar: Var) (from': Expr) (to': Expr) (step: Expr option) (body: Expr) = translation { + let! loopVarName = State.gets (fun context -> context.Namer.LetStart loopVar.Name) + let loopVarType = loopVar.Type + + let! 
loopVarBinding = translateBinding loopVar loopVarName from' + let! condExpr = translateAsExpr to' - do! State.modify (fun context -> context.Namer.LetIn i.Name; context) - let! body = translate loopBody >>= toStb |> State.using clearContext - let cond = Binop(LessEQ, v, condExpr) - let condModifier = Unop(UOp.Incr, v) + let loopCond = Binop(LessEQ, Variable loopVarName, condExpr) + + do! State.modify (fun context -> context.Namer.LetIn loopVar.Name; context) + + let! loopVarModifier = + match step with + | Some step -> + Expr.VarSet( + loopVar, + Expr.Call( + Utils.makeGenericMethodCall [loopVarType; loopVarType; loopVarType] <@ (+) @>, + [Expr.Var loopVar; step] + ) + ) + |> translate + |> State.map (fun node -> node :?> Statement<_>) + | None -> translation { return Unop(UOp.Incr, Variable loopVarName) :> Statement<_> } + + let! loopBody = translate body >>= toStb |> State.using clearContext + do! State.modify (fun context -> context.Namer.LetOut(); context) - return ForIntegerLoop(var, cond, condModifier, body) + + return ForIntegerLoop(loopVarBinding, loopCond, loopVarModifier, loopBody) } let translateWhileLoop condExpr bodyExpr = translation { @@ -428,7 +505,8 @@ module rec Body = do! State.modify (fun context -> context.VarDecls.Clear(); context) for expr in linearized do - do! State.modify (fun context -> context.VarDecls.Clear(); context) + // NOTE something broke here :( + // do! State.modify (fun context -> context.VarDecls.Clear(); context) match! translate expr with | :? StatementBlock as s1 -> decls.AddRange(s1.Statements) @@ -495,12 +573,18 @@ module rec Body = } let translateUnionFieldGet expr (propInfo: PropertyInfo) = translation { - let! unionType = State.gets (fun context -> context.UnionDecls.[expr.Type]) + let! unionType = State.gets (fun context -> context.CStructDecls.[expr.Type]) + let unionType = unionType :?> DiscriminatedUnionType let!
unionValueExpr = translateAsExpr expr let caseName = propInfo.DeclaringType.Name - let unionCaseField = unionType.GetCaseByName caseName + let unionCaseField = + // derived classes are not generated for option, so it has to be handled separately + if caseName <> "FSharpOption`1" then + unionType.GetCaseByName caseName + else + unionType.GetCaseByName "Some" match unionCaseField with | Some unionCaseField -> @@ -519,20 +603,83 @@ module rec Body = ) } + let private translateLet (var: Var) expr inExpr = translation { + let! bName = State.gets (fun context -> context.Namer.LetStart var.Name) + + let! vDecl = translation { + match expr with + | DerivedPatterns.SpecificCall <@@ local @@> (_, _, _) -> + let! vType = Type.translate var.Type + return VarDecl(vType, bName, None, spaceModifier = Local) + | DerivedPatterns.SpecificCall <@@ localArray @@> (_, _, [arg]) -> + let! expr = translateCond arg + let arrayLength = + match expr with + | :? Const as c -> int c.Val + | other -> raise <| InvalidKernelException(sprintf "Calling localArray with a non-const argument %A" other) + let! arrayType = Type.translate var.Type |> State.using (fun ctx -> { ctx with ArrayKind = CArrayDecl arrayLength }) + return VarDecl(arrayType, bName, None, spaceModifier = Local) + | Patterns.DefaultValue _ -> + let! vType = Type.translate var.Type + return VarDecl(vType, bName, None) + | _ -> return! translateBinding var bName expr + } + + do! State.modify (fun context -> context.VarDecls.Add vDecl; context) + do! State.modify (fun context -> context.Namer.LetIn var.Name; context) + + let! sb = State.gets (fun context -> context.VarDecls) + let! res = translate inExpr |> State.using clearContext + + match res with + | :? StatementBlock as s -> sb.AddRange s.Statements + | _ -> sb.Add(res :?> Statement<_>) + + do!
State.modify (fun context -> context.Namer.LetOut(); context) + + return StatementBlock sb :> Node<_> + } + + let private translateProvidedCall expr = translation { + let rec traverse expr args = translation { + match expr with + | Patterns.Value (calledName, sType) -> + match sType.Name.ToLowerInvariant() with + | "string" -> return (calledName :?> string), args + | _ -> return raise <| TranslationFailedException(sprintf "Failed to parse provided call, expected string call name: %O" expr) + | Patterns.Sequential (expr1, expr2) -> + let! updatedArgs = translation { + match expr2 with + | Patterns.Value (null, _) -> return args // the last item in the sequence is null + | _ -> + let! a = translateAsExpr expr2 + return a :: args + } + return! traverse expr1 updatedArgs + | _ -> return raise <| TranslationFailedException(sprintf "Failed to parse provided call: %O" expr) + } + + let! m = traverse expr [] + return FunCall m :> Node<_> + } + let translate expr = translation { + let toNode (x: #Node<_>) = translation { + return x :> Node<_> + } + match expr with | Patterns.AddressOf expr -> return raise <| InvalidKernelException(sprintf "AdressOf is not suported: %O" expr) | Patterns.AddressSet expr -> return raise <| InvalidKernelException(sprintf "AdressSet is not suported: %O" expr) | Patterns.Application (expr1, expr2) -> - let! (e, appling) = translateApplication expr1 expr2 - if appling then + let! (e, applying) = translateApplication expr1 expr2 + if applying then return! translate e else - let! r = translateApplicationFun expr1 expr2 - return r :> Node<_> + return! translateApplicationFun expr1 expr2 >>= toNode - | DerivedPatterns.SpecificCall <@@ PrintfReplacer.print @@> (_, _, args) -> + | DerivedPatterns.SpecificCall <@@ print @@> (_, _, args) -> match args with | [ Patterns.ValueWithName (argTypes, _, _); Patterns.ValueWithName (formatStr, _, _); @@ -551,32 +698,36 @@ module rec Body = ) -> return! 
translate expr - | Patterns.Call (exprOpt, mInfo, args) -> - let! r = translateCall exprOpt mInfo args - return r :> Node<_> + | DerivedPatterns.SpecificCall <@ LanguagePrimitives.GenericOne @> (_, [onType], _) -> + let! type' = Type.translate onType + let value = + Expr.Call( + Utils.makeGenericMethodCall [onType] <@ LanguagePrimitives.GenericOne @>, + List.empty + ).EvaluateUntyped().ToString() + + return Const(type', value) :> Node<_> + + | Patterns.Call (exprOpt, mInfo, args) -> return! translateCall exprOpt mInfo args >>= toNode | Patterns.Coerce (expr, sType) -> return raise <| InvalidKernelException(sprintf "Coerce is not suported: %O" expr) | Patterns.DefaultValue sType -> return raise <| InvalidKernelException(sprintf "DefaulValue is not suported: %O" expr) + | Patterns.FieldGet (exprOpt, fldInfo) -> match exprOpt with - | Some expr -> - let! r = translateStructFieldGet expr fldInfo.Name - return r :> Node<_> + | Some expr -> return! translateStructFieldGet expr fldInfo.Name >>= toNode | None -> return raise <| InvalidKernelException(sprintf "FieldGet for empty host is not suported. Field: %A" fldInfo.Name) + | Patterns.FieldSet (exprOpt, fldInfo, expr) -> match exprOpt with - | Some e -> - let! r = translateFieldSet e fldInfo.Name expr - return r :> Node<_> + | Some e -> return! translateFieldSet e fldInfo.Name expr >>= toNode | None -> return raise <| InvalidKernelException(sprintf "Fileld set with empty host is not supported. Field: %A" fldInfo) - | Patterns.ForIntegerRangeLoop (i, from, _to, _do) -> - let! r = translateForIntegerRangeLoop i from _to _do - return r :> Node<_> - | Patterns.IfThenElse (cond, thenExpr, elseExpr) -> - let! r = translateIf cond thenExpr elseExpr - return r :> Node<_> - | Patterns.Lambda (var, _expr) -> - // translateLambda var expr targetContext - return raise <| InvalidKernelException(sprintf "Lambda is not suported: %A" expr) + + | ForLoopWithStep (loopVar, (start, step, finish), loopBody) -> return! 
translateForLoop loopVar start finish (Some step) loopBody >>= toNode + | ForLoop (loopVar, (start, finish), loopBody) -> return! translateForLoop loopVar start finish None loopBody >>= toNode + | Patterns.ForIntegerRangeLoop (loopVar, start, finish, loopBody) -> return! translateForLoop loopVar start finish None loopBody >>= toNode + | Patterns.IfThenElse (cond, thenExpr, elseExpr) -> return! translateIf cond thenExpr elseExpr >>= toNode + + | Patterns.Lambda (var, _expr) -> return raise <| InvalidKernelException(sprintf "Lambda is not suported: %A" expr) | Patterns.Let (var, expr, inExpr) -> match var.Name with | "___providedCallInfo" -> return! translateProvidedCall expr @@ -585,35 +736,31 @@ module rec Body = | Patterns.LetRecursive (bindings, expr) -> return raise <| InvalidKernelException(sprintf "LetRecursive is not suported: %O" expr) | Patterns.NewArray (sType, exprs) -> return raise <| InvalidKernelException(sprintf "NewArray is not suported: %O" expr) | Patterns.NewDelegate (sType, vars, expr) -> return raise <| InvalidKernelException(sprintf "NewDelegate is not suported: %O" expr) + | Patterns.NewObject (constrInfo, exprs) -> let! context = State.get - let p = constrInfo.GetParameters() - let p2 = constrInfo.GetMethodBody() - // а если перегруженный конструктор? (отсальное нули) - if context.UserDefinedTypes.Contains(constrInfo.DeclaringType) then - let! structInfo = State.gets (fun context -> context.StructDecls.[constrInfo.DeclaringType]) - let cArgs = exprs |> List.map (fun x -> translation { return! translateAsExpr x }) - let res = NewStruct<_>(structInfo, cArgs |> List.map (State.eval context)) - return res :> Node<_> - else - return raise <| InvalidKernelException(sprintf "NewObject is not suported: %O" expr) + // let p = constrInfo. GetParameters() + // let p2 = constrInfo.GetMethodBody() + let! structInfo = Type.translate constrInfo.DeclaringType + let cArgs = exprs |> List.map (fun x -> translation { return! 
translateAsExpr x }) + return NewStruct<_>(structInfo :?> StructType, cArgs |> List.map (State.eval context)) :> Node<_> + | Patterns.NewRecord (sType, exprs) -> let! context = State.get - let! structInfo = Type.translateStruct sType + let! structInfo = Type.translate sType let cArgs = exprs |> List.map (fun x -> translation { return! translateAsExpr x }) - return NewStruct<_>(structInfo, cArgs |> List.map (State.eval context)) :> Node<_> + return NewStruct<_>(structInfo :?> StructType, cArgs |> List.map (State.eval context)) :> Node<_> + | Patterns.NewTuple (exprs) -> let! context = State.get - let! tupleDecl = Type.translateTuple expr.Type + let! tupleDecl = Type.translate expr.Type let cArgs = exprs |> List.map (fun x -> translateAsExpr x) - return NewStruct<_>(tupleDecl, cArgs |> List.map (State.eval context)) :> Node<_> + return NewStruct<_>(tupleDecl :?> StructType, cArgs |> List.map (State.eval context)) :> Node<_> + | Patterns.NewUnionCase (unionCaseInfo, exprs) -> let! context = State.get - let unionType = unionCaseInfo.DeclaringType - if not <| context.UserDefinedTypes.Contains(unionType) then - raise <| InvalidKernelException(sprintf "Union type %s is not registered" unionType.Name) - - let unionInfo = context.UnionDecls.[unionType] + let! 
unionInfo = Type.translate unionCaseInfo.DeclaringType + let unionInfo = unionInfo :?> DiscriminatedUnionType let tag = Const(unionInfo.Tag.Type, string unionCaseInfo.Tag) :> Expression<_> let args = @@ -621,125 +768,51 @@ module rec Body = | None -> [] | Some field -> let structArgs = exprs |> List.map (fun x -> translateAsExpr x) |> List.map (State.eval context) - let data = - NewUnion( - unionInfo.Data.Type :?> UnionClInplaceType<_>, - field.Name, - NewStruct(field.Type :?> StructType<_>, structArgs) - ) - [ data :> Expression<_> ] + NewUnion( + unionInfo.Data.Type :?> UnionClInplaceType<_>, + field.Name, + NewStruct(field.Type :?> StructType<_>, structArgs) + ) :> Expression<_> + |> List.singleton return NewStruct(unionInfo, tag :: args) :> Node<_> - | Patterns.PropertyGet (exprOpt, propInfo, exprs) -> - let! res = translatePropGet exprOpt propInfo exprs - return res :> Node<_> - | Patterns.PropertySet (exprOpt, propInfo, exprs, expr) -> - let! res = translatePropSet exprOpt propInfo exprs expr - return res :> Node<_> - | Patterns.Sequential (expr1, expr2) -> - let! res = translateSeq expr1 expr2 - return res :> Node<_> + + | Patterns.PropertyGet (exprOpt, propInfo, exprs) -> return! translatePropGet exprOpt propInfo exprs >>= toNode + | Patterns.PropertySet (exprOpt, propInfo, exprs, expr) -> return! translatePropSet exprOpt propInfo exprs expr >>= toNode + | Patterns.Sequential (expr1, expr2) -> return! translateSeq expr1 expr2 >>= toNode | Patterns.TryFinally (tryExpr, finallyExpr) -> return raise <| InvalidKernelException(sprintf "TryFinally is not suported: %O" expr) | Patterns.TryWith (expr1, var1, expr2, var2, expr3) -> return raise <| InvalidKernelException(sprintf "TryWith is not suported: %O" expr) - | Patterns.TupleGet (expr, i) -> - let! r = translateStructFieldGet expr ("_" + (string (i + 1))) - return r :> Node<_> + | Patterns.TupleGet (expr, i) -> return! 
translateStructFieldGet expr ("_" + (string (i + 1))) >>= toNode | Patterns.TypeTest (expr, sType) -> return raise <| InvalidKernelException(sprintf "TypeTest is not suported: %O" expr) + | Patterns.UnionCaseTest (expr, unionCaseInfo) -> - let! context = State.get - let unionDecl = context.UnionDecls.[expr.Type] + let! unionInfo = Type.translate unionCaseInfo.DeclaringType + let unionInfo = unionInfo :?> DiscriminatedUnionType let! unionVarExpr = translateAsExpr expr - let unionGetTagExpr = FieldGet(unionVarExpr, unionDecl.Tag.Name) :> Expression<_> - let tagExpr = Const(unionDecl.Tag.Type, string unionCaseInfo.Tag) :> Expression<_> + let unionGetTagExpr = FieldGet(unionVarExpr, unionInfo.Tag.Name) :> Expression<_> + // NOTE Const pog for genericOne + let tagExpr = Const(unionInfo.Tag.Type, string unionCaseInfo.Tag) :> Expression<_> + + return Binop(EQ, unionGetTagExpr, tagExpr) :> Node<_> - return Binop(BOp.EQ, unionGetTagExpr, tagExpr) :> Node<_> - | Patterns.ValueWithName (_obj, sType, name) -> + | Patterns.ValueWithName (obj', sType, name) -> let! context = State.get // Here is the only use of TranslationContext.InLocal if sType.ToString().EndsWith "[]" (*&& not context.InLocal*) then context.Namer.AddVar name - let! res = translateValue _obj sType + let! res = translateValue obj' sType context.TopLevelVarsDecls.Add( VarDecl(res.Type, name, Some(res :> Expression<_>), AddressSpaceQualifier.Constant) ) let var = Var(name, sType) - let! res = translateVar var - return res :> Node<_> + return! translateVar var >>= toNode else - let! res = translateValue _obj sType - return res :> Node<_> - | Patterns.Value (_obj, sType) -> - let! res = translateValue _obj sType - return res :> Node<_> - | Patterns.Var var -> - let! res = translateVar var - return res :> Node<_> - | Patterns.VarSet (var, expr) -> - let! res = translateVarSet var expr - return res :> Node<_> - | Patterns.WhileLoop (condExpr, bodyExpr) -> - let! 
r = translateWhileLoop condExpr bodyExpr - return r :> Node<_> - | _ -> return raise <| InvalidKernelException(sprintf "Folowing expression inside kernel is not supported:\n%O" expr) - } - - let private translateLet var expr inExpr = translation { - let! bName = State.gets (fun context -> context.Namer.LetStart var.Name) + return! translateValue obj' sType >>= toNode - let! vDecl = translation { - match expr with - | DerivedPatterns.SpecificCall <@@ local @@> (_, _, _) -> - let! vType = Type.translate var.Type - return VarDecl(vType, bName, None, spaceModifier = Local) - | DerivedPatterns.SpecificCall <@@ localArray @@> (_, _, [arg]) -> - let! expr = translateCond arg - let arrayLength = - match expr with - | :? Const as c -> int c.Val - | other -> raise <| InvalidKernelException(sprintf "Calling localArray with a non-const argument %A" other) - let! arrayType = Type.translate var.Type |> State.using (fun ctx -> { ctx with ArrayKind = CArrayDecl arrayLength }) - return VarDecl(arrayType, bName, None, spaceModifier = Local) - | Patterns.DefaultValue _ -> - let! vType = Type.translate var.Type - return VarDecl(vType, bName, None) - | _ -> return! translateBinding var bName expr - } - - do! State.modify (fun context -> context.VarDecls.Add vDecl; context) - do! State.modify (fun context -> context.Namer.LetIn var.Name; context) - - let! sb = State.gets (fun context -> context.VarDecls) - let! res = translate inExpr |> State.using clearContext - - match res with - | :? StatementBlock as s -> sb.AddRange s.Statements - | _ -> sb.Add(res :?> Statement<_>) - - do! 
State.modify (fun context -> context.Namer.LetOut(); context) - - return StatementBlock sb :> Node<_> - } - - let private translateProvidedCall expr = translation { - let rec traverse expr args = translation { - match expr with - | Patterns.Value (calledName, sType) -> - match sType.Name.ToLowerInvariant() with - | "string" -> return (calledName :?> string), args - | _ -> return raise <| TranslationFailedException(sprintf "Failed to parse provided call, expected string call name: %O" expr) - | Patterns.Sequential (expr1, expr2) -> - let! updatedArgs = translation { - match expr2 with - | Patterns.Value (null, _) -> return args // the last item in the sequence is null - | _ -> - let! a = translateAsExpr expr2 - return a :: args - } - return! traverse expr1 updatedArgs - | _ -> return raise <| TranslationFailedException(sprintf "Failed to parse provided call: %O" expr) - } - - let! m = traverse expr [] - return FunCall m :> Node<_> + | Patterns.Value (obj', sType) -> return! translateValue obj' sType >>= toNode + | Patterns.Var var -> return! translateVar var >>= toNode + | Patterns.VarSet (var, expr) -> return! translateVarSet var expr >>= toNode + | Patterns.WhileLoop (condExpr, bodyExpr) -> return! 
translateWhileLoop condExpr bodyExpr >>= toNode + | _ -> return raise <| InvalidKernelException(sprintf "Folowing expression inside kernel is not supported:\n%O" expr) } diff --git a/src/Brahma.FSharp.OpenCL.Translator/Brahma.FSharp.OpenCL.Translator.fsproj b/src/Brahma.FSharp.OpenCL.Translator/Brahma.FSharp.OpenCL.Translator.fsproj index 215e7d81..91cf5c87 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Brahma.FSharp.OpenCL.Translator.fsproj +++ b/src/Brahma.FSharp.OpenCL.Translator/Brahma.FSharp.OpenCL.Translator.fsproj @@ -16,14 +16,15 @@ - + + @@ -43,4 +44,4 @@ - \ No newline at end of file + diff --git a/src/Brahma.FSharp.OpenCL.Translator/Exceptions.fs b/src/Brahma.FSharp.OpenCL.Translator/Exceptions.fs index d6316125..11e36848 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Exceptions.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Exceptions.fs @@ -1,7 +1,19 @@ namespace Brahma.FSharp.OpenCL.Translator +open System + /// The exception that is thrown when the kernel has invalid format. -exception InvalidKernelException of string +type InvalidKernelException = + inherit Exception + + new() = { inherit Exception() } // + new(message: string) = { inherit Exception(message) } + new(message: string, inner: Exception) = { inherit Exception(message, inner) } // /// The exception that is thrown when the unexpected error occured during the translation. 
-exception TranslationFailedException of string +type TranslationFailedException = + inherit Exception + + new() = { inherit Exception() } // + new(message: string) = { inherit Exception(message) } + new(message: string, inner: Exception) = { inherit Exception(message, inner) } diff --git a/src/Brahma.FSharp.OpenCL.Translator/Methods.fs b/src/Brahma.FSharp.OpenCL.Translator/Methods.fs index b910c805..64b4bfe8 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Methods.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Methods.fs @@ -4,7 +4,7 @@ open Microsoft.FSharp.Quotations open Brahma.FSharp.OpenCL.AST [] -type Method(var: Var, expr: Expr, context: TranslationContext>) = +type Method(var: Var, expr: Expr) = member this.FunVar = var member this.FunExpr = expr @@ -33,121 +33,123 @@ type Method(var: Var, expr: Expr, context: TranslationContext StatementBlock * TranslationContext> - default this.TranslateBody(args, body) = - let (newBody, context) = - let clonedContext = context.Copy() + abstract TranslateBody : Var list * Expr -> State> + default this.TranslateBody(args, body) = translation { + let! context = State.get - clonedContext.Namer.LetIn() - args |> List.iter (fun v -> clonedContext.Namer.AddVar v.Name) + context.Namer.LetIn() + args |> List.iter (fun v -> context.Namer.AddVar v.Name) - Body.translate body |> State.run clonedContext + let! newBody = Body.translate body - match newBody with - | :? StatementBlock as sb -> sb - | :? Statement as s -> StatementBlock <| ResizeArray [s] - | _ -> failwithf "Incorrect function body: %A" newBody - , context + return + match newBody with + | :? StatementBlock as sb -> sb + | :? 
Statement as s -> StatementBlock <| ResizeArray [s] + | _ -> failwithf "Incorrect function body: %A" newBody + } - abstract TranslateArgs : Var list * string list * string list * TranslationContext> -> FunFormalArg list + abstract TranslateArgs : Var list * string list * string list -> State list> - abstract BuildFunction : FunFormalArg list * StatementBlock * TranslationContext> -> ITopDef + abstract BuildFunction : FunFormalArg list * StatementBlock -> State> - abstract GetPragmas : TranslationContext> -> ITopDef list - default this.GetPragmas(context) = - let pragmas = ResizeArray() + abstract GetTopLevelVarDecls : unit -> State list> + default this.GetTopLevelVarDecls() = translation { + let! context = State.get - if context.Flags.enableAtomic then - pragmas.Add(CLPragma CLGlobalInt32BaseAtomics :> ITopDef<_>) - pragmas.Add(CLPragma CLLocalInt32BaseAtomics :> ITopDef<_>) + return + context.TopLevelVarsDecls + |> Seq.cast<_> + |> List.ofSeq + } - if context.Flags.enableFP64 then - pragmas.Add(CLPragma CLFP64) + abstract Translate : string list * string list -> State list> + default this.Translate(globalVars, localVars) = translation { + do! State.modify (fun context -> context.WithNewLocalContext()) - List.ofSeq pragmas - - abstract GetTopLevelVarDecls : TranslationContext> -> ITopDef list - default this.GetTopLevelVarDecls(context) = - context.TopLevelVarsDecls - |> Seq.cast<_> - |> List.ofSeq - - abstract Translate : string list * string list -> ITopDef list - default this.Translate(globalVars, localVars) = match expr with | DerivedPatterns.Lambdas (args, body) -> let args = List.collect id args - let (translatedBody, context) = this.TranslateBody(args, body) - let translatedArgs = this.TranslateArgs(args, globalVars, localVars, context) - let func = this.BuildFunction(translatedArgs, translatedBody, context) - let pragmas = this.GetPragmas(context) - let topLevelVarDecls = this.GetTopLevelVarDecls(context) + let! 
translatedBody = this.TranslateBody(args, body) + let! translatedArgs = this.TranslateArgs(args, globalVars, localVars) + let! func = this.BuildFunction(translatedArgs, translatedBody) + let! topLevelVarDecls = this.GetTopLevelVarDecls() - pragmas - @ topLevelVarDecls - @ [func] + return topLevelVarDecls @ [func] - | _ -> failwithf "Incorrect OpenCL quotation: %A" expr + | _ -> return failwithf "Incorrect OpenCL quotation: %A" expr + } override this.ToString() = sprintf "%A\n%A" var expr -type KernelFunc(var: Var, expr: Expr, context: TranslationContext>) = - inherit Method(var, expr, context) +type KernelFunc(var: Var, expr: Expr) = + inherit Method(var, expr) + + override this.TranslateArgs(args, _, _) = translation { + let! context = State.get - override this.TranslateArgs(args, _, _, context) = let brahmaDimensionsTypes = [ Range1D_ Range2D_ Range3D_ ] - args - |> List.filter - (fun (variable: Var) -> - brahmaDimensionsTypes - |> (not << List.contains (variable.Type.Name.ToLowerInvariant())) - ) - |> List.map - (fun variable -> + return + args + |> List.filter + (fun (variable: Var) -> + brahmaDimensionsTypes + |> (not << List.contains (variable.Type.Name.ToLowerInvariant())) + ) + |> List.map + (fun variable -> + let vType = Type.translate variable.Type |> State.eval context + let declSpecs = DeclSpecifierPack(typeSpecifier = vType) + + if vType :? RefType<_> then + declSpecs.AddressSpaceQualifier <- Global + + FunFormalArg(declSpecs, variable.Name) + ) + } + + override this.BuildFunction(args, body) = translation { + let retFunType = PrimitiveType Void :> Type<_> + let declSpecs = DeclSpecifierPack(typeSpecifier = retFunType, funQualifier = Kernel) + return FunDecl(declSpecs, var.Name, args, body) :> ITopDef<_> + } + +type Function(var: Var, expr: Expr) = + inherit Method(var, expr) + + override this.TranslateArgs(args, globalVars, localVars) = translation { + let! 
context = State.get + + return + args + |> List.map (fun variable -> let vType = Type.translate variable.Type |> State.eval context let declSpecs = DeclSpecifierPack(typeSpecifier = vType) - if vType :? RefType<_> then + if + vType :? RefType<_> && + globalVars |> List.contains variable.Name + then declSpecs.AddressSpaceQualifier <- Global + elif + vType :? RefType<_> && + localVars |> List.contains variable.Name + then + declSpecs.AddressSpaceQualifier <- Local FunFormalArg(declSpecs, variable.Name) ) + } + + override this.BuildFunction(args, body) = translation { + let! context = State.get - override this.BuildFunction(args, body, _) = - let retFunType = PrimitiveType Void :> Type<_> - let declSpecs = DeclSpecifierPack(typeSpecifier = retFunType, funQualifier = Kernel) - FunDecl(declSpecs, var.Name, args, body) :> ITopDef<_> - -type Function(var: Var, expr: Expr, context: TranslationContext>) = - inherit Method(var, expr, context) - - override this.TranslateArgs(args, globalVars, localVars, context) = - args - |> List.map (fun variable -> - let vType = Type.translate variable.Type |> State.eval context - let declSpecs = DeclSpecifierPack(typeSpecifier = vType) - - if - vType :? RefType<_> && - globalVars |> List.contains variable.Name - then - declSpecs.AddressSpaceQualifier <- Global - elif - vType :? 
RefType<_> && - localVars |> List.contains variable.Name - then - declSpecs.AddressSpaceQualifier <- Local - - FunFormalArg(declSpecs, variable.Name) - ) - - override this.BuildFunction(args, body, context) = let retFunType = Type.translate var.Type |> State.eval context let declSpecs = DeclSpecifierPack(typeSpecifier = retFunType) let partAST = @@ -155,42 +157,50 @@ type Function(var: Var, expr: Expr, context: TranslationContext as t when t.Type = Void -> body :> Statement<_> | _ -> this.AddReturn(body) - FunDecl(declSpecs, var.Name, args, partAST) :> ITopDef<_> + return FunDecl(declSpecs, var.Name, args, partAST) :> ITopDef<_> + } + +type AtomicFunc(var: Var, expr: Expr, qual: AddressSpaceQualifier) = + inherit Method(var, expr) -type AtomicFunc(var: Var, expr: Expr, qual: AddressSpaceQualifier, context: TranslationContext>) = - inherit Method(var, expr, context) + override this.TranslateArgs(args, globalVars, localVars) = translation { + let! context = State.get - override this.TranslateArgs(args, globalVars, localVars, context) = let firstNonMutexIdx = args |> List.tryFindIndex (fun v -> not <| v.Name.EndsWith "Mutex") |> Option.defaultValue 0 - args - |> List.mapi - (fun i variable -> - let vType = Type.translate variable.Type |> State.eval context - let declSpecs = DeclSpecifierPack(typeSpecifier = vType) - - if i = firstNonMutexIdx then - declSpecs.AddressSpaceQualifier <- qual - elif - vType :? RefType<_> && - globalVars |> List.contains variable.Name - then - declSpecs.AddressSpaceQualifier <- Global - elif - vType :? RefType<_> && - localVars |> List.contains variable.Name - then - declSpecs.AddressSpaceQualifier <- Local - - FunFormalArg(declSpecs, variable.Name) - ) + return + args + |> List.mapi + (fun i variable -> + let vType = Type.translate variable.Type |> State.eval context + let declSpecs = DeclSpecifierPack(typeSpecifier = vType) + + if i = firstNonMutexIdx then + declSpecs.AddressSpaceQualifier <- qual + elif + vType :? 
RefType<_> && + globalVars |> List.contains variable.Name + then + declSpecs.AddressSpaceQualifier <- Global + elif + vType :? RefType<_> && + localVars |> List.contains variable.Name + then + declSpecs.AddressSpaceQualifier <- Local + + FunFormalArg(declSpecs, variable.Name) + ) + } + + override this.BuildFunction(args, body) = translation { + let! context = State.get - override this.BuildFunction(args, body, context) = let retFunType = Type.translate var.Type |> State.eval context let declSpecs = DeclSpecifierPack(typeSpecifier = retFunType) let partAST = this.AddReturn body - FunDecl(declSpecs, var.Name, args, partAST) :> ITopDef<_> + return FunDecl(declSpecs, var.Name, args, partAST) :> ITopDef<_> + } diff --git a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/AtomicProcessor.fs b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/AtomicProcessor.fs index 8290b5f7..ae79561b 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/AtomicProcessor.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/AtomicProcessor.fs @@ -32,17 +32,17 @@ module AtomicProcessor = let inline private atomicOr p v = (|||) !p v let inline private atomicXor p v = (^^^) !p v - let private atomicAddInfo = (Utils.getMethodInfoOfLambda <@ atomicAdd @>).GetGenericMethodDefinition() - let private atomicSubInfo = (Utils.getMethodInfoOfLambda <@ atomicSub @>).GetGenericMethodDefinition() - let private atomicIncInfo = (Utils.getMethodInfoOfLambda <@ atomicInc @>).GetGenericMethodDefinition() - let private atomicDecInfo = (Utils.getMethodInfoOfLambda <@ atomicDec @>).GetGenericMethodDefinition() - let private atomicXchgInfo = (Utils.getMethodInfoOfLambda <@ atomicXchg @>).GetGenericMethodDefinition() - let private atomicCmpxchgInfo = (Utils.getMethodInfoOfLambda <@ atomicCmpxchg @>).GetGenericMethodDefinition() - let private atomicMinInfo = (Utils.getMethodInfoOfLambda <@ atomicMin @>).GetGenericMethodDefinition() - let private atomicMaxInfo = 
(Utils.getMethodInfoOfLambda <@ atomicMax @>).GetGenericMethodDefinition() - let private atomicAndInfo = (Utils.getMethodInfoOfLambda <@ atomicAnd @>).GetGenericMethodDefinition() - let private atomicOrInfo = (Utils.getMethodInfoOfLambda <@ atomicOr @>).GetGenericMethodDefinition() - let private atomicXorInfo = (Utils.getMethodInfoOfLambda <@ atomicXor @>).GetGenericMethodDefinition() + let private atomicAddInfo = (Utils.getMethodInfoOfCall <@ atomicAdd @>).GetGenericMethodDefinition() + let private atomicSubInfo = (Utils.getMethodInfoOfCall <@ atomicSub @>).GetGenericMethodDefinition() + let private atomicIncInfo = (Utils.getMethodInfoOfCall <@ atomicInc @>).GetGenericMethodDefinition() + let private atomicDecInfo = (Utils.getMethodInfoOfCall <@ atomicDec @>).GetGenericMethodDefinition() + let private atomicXchgInfo = (Utils.getMethodInfoOfCall <@ atomicXchg @>).GetGenericMethodDefinition() + let private atomicCmpxchgInfo = (Utils.getMethodInfoOfCall <@ atomicCmpxchg @>).GetGenericMethodDefinition() + let private atomicMinInfo = (Utils.getMethodInfoOfCall <@ atomicMin @>).GetGenericMethodDefinition() + let private atomicMaxInfo = (Utils.getMethodInfoOfCall <@ atomicMax @>).GetGenericMethodDefinition() + let private atomicAndInfo = (Utils.getMethodInfoOfCall <@ atomicAnd @>).GetGenericMethodDefinition() + let private atomicOrInfo = (Utils.getMethodInfoOfCall <@ atomicOr @>).GetGenericMethodDefinition() + let private atomicXorInfo = (Utils.getMethodInfoOfCall <@ atomicXor @>).GetGenericMethodDefinition() let private modifyFirstOfList f lst = match lst with @@ -192,70 +192,36 @@ module AtomicProcessor = let baseFuncBody = match lambdaBody with - | DerivedPatterns.SpecificCall <@ inc @> (_, onType :: _, _) -> - failwithf "Atomic inc for %O is not suppotred" onType - - | DerivedPatterns.SpecificCall <@ dec @> (_, onType :: _, _) -> - failwithf "Atomic inc for %O is not suppotred" onType - - // | DerivedPatterns.SpecificCall <@ inc @> (_, onType :: _, [Patterns.Var p]) 
-> - // Expr.Call( - // (Utils.getMethodInfoOfLambda <@ (+) @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType, onType, onType), - - // [ - // Expr.Var p; - // Expr.Call( - // (Utils.getMethodInfoOfLambda <@ unbox @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType), - - // Expr.Value( - // Expr.Call( - // (Utils.getMethodInfoOfCall <@ GenericOne @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType), - // List.empty - // ).EvaluateUntyped() - // ) - // |> List.singleton - // ) - // ] - // ) - - // | DerivedPatterns.SpecificCall <@ dec @> (_, onType :: _, [Patterns.Var p]) -> - // Expr.Call( - // (Utils.getMethodInfoOfCall <@ (-) @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType, onType, onType), - - // [ - // Expr.Var p; - // Expr.Call( - // (Utils.getMethodInfoOfCall <@ unbox @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType), - - // Expr.Value( - // Expr.Call( - // (Utils.getMethodInfoOfCall <@ GenericOne @>) - // .GetGenericMethodDefinition() - // .MakeGenericMethod(onType), - // List.empty - // ).EvaluateUntyped() - // ) - // |> List.singleton - // ) - // ] - // ) + | DerivedPatterns.SpecificCall <@ inc @> (_, onType :: _, [Patterns.Var p]) -> + Expr.Call( + Utils.makeGenericMethodCall [onType; onType; onType] <@ (+) @>, + [ + Expr.Var p; + Expr.Call( + Utils.makeGenericMethodCall [onType] <@ GenericOne @>, + List.empty + ) + ] + ) + + | DerivedPatterns.SpecificCall <@ dec @> (_, onType :: _, [Patterns.Var p]) -> + Expr.Call( + Utils.makeGenericMethodCall [onType; onType; onType] <@ (-) @>, + [ + Expr.Var p; + Expr.Call( + Utils.makeGenericMethodCall [onType] <@ GenericOne @>, + List.empty + ) + ] + ) | DerivedPatterns.SpecificCall <@ xchg @> (_, _, [Patterns.Var p; Patterns.Var value]) -> Expr.Var value | DerivedPatterns.SpecificCall <@ cmpxchg @> (_, onType :: _, [Patterns.Var p; Patterns.Var cmp; Patterns.Var value]) -> Expr.IfThenElse( - 
Expr.Call((Utils.getMethodInfoOfLambda <@ (=) @>).GetGenericMethodDefinition().MakeGenericMethod(onType), [Expr.Var p; Expr.Var cmp]), + Expr.Call(Utils.makeGenericMethodCall [onType] <@ (=) @>, [Expr.Var p; Expr.Var cmp]), Expr.Var value, Expr.Var p ) @@ -320,7 +286,7 @@ module AtomicProcessor = | DerivedPatterns.SpecificCall <@ IntrinsicFunctions.GetArray @> (_, _, [Patterns.Var _; idx]) -> Expr.Call( - Utils.getMethodInfoOfLambda <@ IntrinsicFunctions.GetArray @>, + Utils.getMethodInfoOfCall <@ IntrinsicFunctions.GetArray @>, [Expr.Var mutexVar; idx] ) @@ -361,7 +327,7 @@ module AtomicProcessor = flag <- false // HACK needed for nvidia, but broken for intel cpu //barrier () - barrier () + barrierLocal () @@>, Expr.Var oldValueVar ) @@ -459,7 +425,7 @@ module AtomicProcessor = | Some mutexVar -> Expr.Let( mutexVar, - Expr.Call(Utils.getMethodInfoOfLambda <@ localArray @>, args), + Expr.Call(Utils.getMethodInfoOfCall <@ localArray @>, args), Expr.Sequential( <@@ if Anchors._localID0 = 0 then @@ -470,7 +436,7 @@ module AtomicProcessor = Expr.Value 0, <@@ (%%args.[0] : int) - 1 @@>, Expr.Call( - Utils.getMethodInfoOfLambda <@ IntrinsicFunctions.SetArray @>, + Utils.getMethodInfoOfCall <@ IntrinsicFunctions.SetArray @>, [ Expr.Var mutexVar Expr.Var i @@ -479,7 +445,7 @@ module AtomicProcessor = ) ) ) - barrier () + barrierLocal () @@>, inExpr ) diff --git a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/TransformMinMax.fs b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/TransformMinMax.fs new file mode 100644 index 00000000..4c71bf27 --- /dev/null +++ b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/TransformMinMax.fs @@ -0,0 +1,42 @@ +namespace Brahma.FSharp.OpenCL.Translator.QuotationTransformers + +open FSharp.Quotations +open Brahma.FSharp.OpenCL.Translator + +[] +module TransformerMinMax = + let helper (expr: Expr) (type': System.Type) (x: Expr) (y: Expr) = + let cachedXVar = Var("tempVarX", type') + let cachedYVar = 
Var("tempVarY", type') + + Expr.Let( + cachedXVar, + x, + Expr.Let( + cachedYVar, + y, + Expr.IfThenElse( + Expr.Call( + Utils.makeGenericMethodCall [type'] expr, + [Expr.Var cachedXVar; Expr.Var cachedYVar] + ), + Expr.Var cachedXVar, + Expr.Var cachedYVar + ) + ) + ) + + let rec transformMinMax (expr: Expr) = + match expr with + | DerivedPatterns.SpecificCall <@@ max @@> (_, [genericParam], [x; y]) -> + helper <@@ (>) @@> genericParam x y + + | DerivedPatterns.SpecificCall <@@ min @@> (_, [genericParam], [x; y]) -> + helper <@@ (<) @@> genericParam x y + + | ExprShape.ShapeVar _ -> + expr + | ExprShape.ShapeLambda (x, body) -> + Expr.Lambda(x, transformMinMax body) + | ExprShape.ShapeCombination(combo, exprList) -> + ExprShape.RebuildShapeCombination(combo, List.map transformMinMax exprList) diff --git a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Transformer.fs b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Transformer.fs index 84a3cab5..5dcd72c1 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Transformer.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Transformer.fs @@ -24,8 +24,9 @@ module Transformer = let preprocessQuotation expr = replacePrintf expr /// Returns kernel and other methods - let transformQuotation (expr: Expr) (translatorOptions: TranslatorOption list) = + let transformQuotation (expr: Expr) = expr + // |> transformMinMax |> processAtomic |> replacePrintf |> makeVarNameUnique diff --git a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Patterns.fs b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Patterns.fs index 433b4df3..bfc4d73f 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Patterns.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Patterns.fs @@ -92,7 +92,7 @@ module Patterns = let rec private uncurryLambda (expr: Expr) = match expr with | ExprShape.ShapeLambda (var, 
body) -> - let args, innerBody = uncurryLambda body + let (args, innerBody) = uncurryLambda body var :: args, innerBody | _ -> [], expr diff --git a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Utils.fs b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Utils.fs index fbc6aa81..b6db8d30 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Utils.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/QuotationTransformers/Utilities/Utils.fs @@ -116,16 +116,6 @@ module Utils = Expr.Call (newMethodInfo, [reference; value]) | _ -> failwithf "createReferenceSetCall: (:=) is not more a Call expression" - let getMethodInfoOfLambda (expr: Expr) = - match expr with - | DerivedPatterns.Lambdas (args, Patterns.Call (_, mInfo, _)) -> mInfo - | _ -> failwithf "Expression is not lambda, but %O" expr - - let getMethodInfoOfCall (expr: Expr) = - match expr with - | Patterns.Call (_, mInfo, _) -> mInfo - | _ -> failwithf "Expression is not call, but %O" expr - let isGlobal (var: Var) = var.Type.Name.ToLower().StartsWith ClArray_ || var.Type.Name.ToLower().StartsWith ClCell_ diff --git a/src/Brahma.FSharp.OpenCL.Translator/TranslationContext.fs b/src/Brahma.FSharp.OpenCL.Translator/TranslationContext.fs index 1e8d261e..d5ecd133 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/TranslationContext.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/TranslationContext.fs @@ -23,51 +23,45 @@ type ArrayKind = | CPointer | CArrayDecl of size: int -type Flags() = - member val enableAtomic = false with get, set - member val enableFP64 = false with get, set +type Flag = + | EnableAtomic + | EnableFP64 -type TranslatorOption = - | UseNativeBooleanType - | BoolAsBit +type TranslatorOptions() = + member val UseNativeBooleanType = false with get, set + member val BoolAsBit = false with get, set type TranslationContext<'lang, 'vDecl> = { TopLevelVarsDecls: ResizeArray<'vDecl> - UserDefinedTypes: HashSet - // NOTE is it necessary to have 3 dicts? 
- TupleDecls: Dictionary> - StructDecls: Dictionary> - UnionDecls: Dictionary> + CStructDecls: Dictionary> VarDecls: ResizeArray<'vDecl> Namer: Namer ArrayKind: ArrayKind - Flags: Flags - TranslatorOptions: TranslatorOption list + Flags: HashSet + TranslatorOptions: TranslatorOptions } - static member Create([] translatorOptions: TranslatorOption[]) = + static member Create() = { TopLevelVarsDecls = ResizeArray<'vDecl>() - UserDefinedTypes = HashSet() - TupleDecls = Dictionary>() - StructDecls = Dictionary>() - UnionDecls = Dictionary>() + CStructDecls = Dictionary>() VarDecls = ResizeArray<'vDecl>() Namer = Namer() ArrayKind = CPointer - Flags = Flags() - TranslatorOptions = translatorOptions |> Array.toList + Flags = HashSet() + TranslatorOptions = TranslatorOptions() } - member this.Copy() = + member this.WithNewLocalContext() = { this with VarDecls = ResizeArray() Namer = Namer() + ArrayKind = CPointer } type TargetContext = TranslationContext> diff --git a/src/Brahma.FSharp.OpenCL.Translator/Translator.fs b/src/Brahma.FSharp.OpenCL.Translator/Translator.fs index fab5460c..a45bdd39 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Translator.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Translator.fs @@ -21,7 +21,7 @@ open Brahma.FSharp.OpenCL.Translator.QuotationTransformers open System open System.Collections.Generic -type FSQuotationToOpenCLTranslator([] translatorOptions: TranslatorOption[]) = +type FSQuotationToOpenCLTranslator(translatorOptions: TranslatorOptions) = let mainKernelName = "brahmaKernel" let lockObject = obj () @@ -40,18 +40,24 @@ type FSQuotationToOpenCLTranslator([] translatorOptions: TranslatorO let atomicApplicationsInfo = let atomicPointerArgQualifiers = Dictionary>() + let (|AtomicApplArgs|_|) (args: Expr list list) = + match args with + | [mutex] :: _ :: [[DerivedPatterns.SpecificCall <@ ref @> (_, _, [Patterns.ValidVolatileArg var])]] + | [mutex] :: [[DerivedPatterns.SpecificCall <@ ref @> (_, _, [Patterns.ValidVolatileArg var])]] -> 
Some (mutex, var) + | _ -> None + let rec go expr = match expr with | DerivedPatterns.Applications ( Patterns.Var funcVar, - [mutex] :: [DerivedPatterns.SpecificCall <@ ref @> (_, _, [Patterns.ValidVolatileArg var])] :: _ + AtomicApplArgs (_, volatileVar) ) when funcVar.Name.StartsWith "atomic" -> - if kernelArgumentsNames |> List.contains var.Name then + if kernelArgumentsNames |> List.contains volatileVar.Name then atomicPointerArgQualifiers.Add(funcVar, Global) - elif localVarsNames |> List.contains var.Name then + elif localVarsNames |> List.contains volatileVar.Name then atomicPointerArgQualifiers.Add(funcVar, Local) else failwith "Atomic pointer argument should be from local or global memory only" @@ -71,51 +77,55 @@ type FSQuotationToOpenCLTranslator([] translatorOptions: TranslatorO kernelArgumentsNames, localVarsNames, atomicApplicationsInfo - let constructMethods (expr: Expr) (functions: (Var * Expr) list) (atomicApplicationsInfo: Map>) context = - let kernelFunc = KernelFunc(Var(mainKernelName, expr.Type), expr, context) :> Method |> List.singleton + let constructMethods (expr: Expr) (functions: (Var * Expr) list) (atomicApplicationsInfo: Map>) = + let kernelFunc = KernelFunc(Var(mainKernelName, expr.Type), expr) :> Method |> List.singleton let methods = functions |> List.map (fun (var, expr) -> match atomicApplicationsInfo |> Map.tryFind var with - | Some qual -> AtomicFunc(var, expr, qual, context) :> Method - | None -> Function(var, expr, context) :> Method + | Some qual -> AtomicFunc(var, expr, qual) :> Method + | None -> Function(var, expr) :> Method ) methods @ kernelFunc - let translate expr' translatorOptions = + let translate expr' = let expr = preprocessQuotation expr' let context = TranslationContext.Create() // TODO: Extract quotationTransformer to translator - let (kernelExpr, functions) = transformQuotation expr translatorOptions + let (kernelExpr, functions) = transformQuotation expr let (globalVars, localVars, atomicApplicationsInfo) = 
collectData kernelExpr functions - let methods = constructMethods kernelExpr functions atomicApplicationsInfo context + let methods = constructMethods kernelExpr functions atomicApplicationsInfo let clFuncs = ResizeArray() for method in methods do - clFuncs.AddRange(method.Translate(globalVars, localVars)) + clFuncs.AddRange(method.Translate(globalVars, localVars) |> State.eval context) + + let pragmas = + let pragmas = ResizeArray() + + context.Flags + |> Seq.iter (fun (flag: Flag) -> + match flag with + | EnableAtomic -> + pragmas.Add(CLPragma CLGlobalInt32BaseAtomics :> ITopDef<_>) + pragmas.Add(CLPragma CLLocalInt32BaseAtomics :> ITopDef<_>) + | EnableFP64 -> + pragmas.Add(CLPragma CLFP64) + ) + + List.ofSeq pragmas let userDefinedTypes = - context.UserDefinedTypes - |> Seq.map - (fun type' -> - if context.StructDecls.ContainsKey type' then - context.StructDecls.[type'] - elif context.TupleDecls.ContainsKey type' then - context.TupleDecls.[type'] - elif context.UnionDecls.ContainsKey type' then - context.UnionDecls.[type'] :> StructType<_> - else - failwith "Something went wrong :( This error shouldn't occur" - ) + context.CStructDecls.Values |> Seq.map StructDecl |> Seq.cast> |> List.ofSeq - AST <| userDefinedTypes @ List.ofSeq clFuncs, + AST(pragmas @ userDefinedTypes @ List.ofSeq clFuncs), methods |> List.find (fun method -> method :? KernelFunc) |> fun kernel -> kernel.FunExpr @@ -124,4 +134,4 @@ type FSQuotationToOpenCLTranslator([] translatorOptions: TranslatorO member this.Translate(qExpr) = lock lockObject <| fun () -> - translate qExpr (List.ofArray translatorOptions) + translate qExpr diff --git a/src/Brahma.FSharp.OpenCL.Translator/Type.fs b/src/Brahma.FSharp.OpenCL.Translator/Type.fs index c33f888e..56479464 100644 --- a/src/Brahma.FSharp.OpenCL.Translator/Type.fs +++ b/src/Brahma.FSharp.OpenCL.Translator/Type.fs @@ -59,11 +59,11 @@ module rec Type = | Name "unit" -> return PrimitiveType(Void) :> Type | Name "float" | Name "double" -> - do! 
State.modify (fun ctx -> ctx.Flags.enableFP64 <- true; ctx) + do! State.modify (fun ctx -> ctx.Flags.Add EnableFP64 |> ignore; ctx) return PrimitiveType(Double) :> Type | Name "boolean" -> - match! State.gets (fun ctx -> ctx.TranslatorOptions |> List.contains UseNativeBooleanType) with + match! State.gets (fun ctx -> ctx.TranslatorOptions.UseNativeBooleanType) with | true -> return PrimitiveType(Bool) :> Type | false -> return PrimitiveType(BoolClAlias) :> Type @@ -100,14 +100,18 @@ module rec Type = let! translated = translateStruct type' return translated :> Type<_> + | _ when FSharpType.IsUnion type' -> + let! translated = translateUnion type' + return translated :> Type<_> + | other -> return failwithf "Unsupported kernel type: %A" other } let translateStruct (type': System.Type) = translation { let! context = State.get - if context.StructDecls.ContainsKey type' then - return context.StructDecls.[type'] + if context.CStructDecls.ContainsKey type' then + return context.CStructDecls.[type'] else let! fields = [ @@ -126,18 +130,17 @@ module rec Type = let fields = fields |> List.distinct - let! index = State.gets (fun ctx -> ctx.StructDecls.Count) + let! index = State.gets (fun ctx -> ctx.CStructDecls.Count) let structType = StructType(sprintf "struct%i" index, fields) - do! State.modify (fun context -> context.StructDecls.Add(type', structType); context) - context.UserDefinedTypes.Add type' |> ignore + do! State.modify (fun context -> context.CStructDecls.Add(type', structType); context) return structType } let translateTuple (type': System.Type) = translation { let! context = State.get - if context.StructDecls.ContainsKey type' then - return context.StructDecls.[type'] + if context.CStructDecls.ContainsKey type' then + return context.CStructDecls.[type'] else let genericTypeArguments = FSharpType.GetTupleElements type' |> List.ofArray @@ -153,46 +156,43 @@ module rec Type = }) |> State.collect - match! 
State.gets (fun ctx -> ctx.StructDecls.ContainsKey type') with - | true -> - return! State.gets (fun ctx -> ctx.StructDecls.[type']) - | false -> - let! index = State.gets (fun ctx -> ctx.StructDecls.Count) - let tupleDecl = StructType(sprintf "tuple%i" index, elements) - do! State.modify (fun ctx -> ctx.StructDecls.Add(type', tupleDecl); ctx) - context.UserDefinedTypes.Add type' |> ignore - return tupleDecl + let! index = State.gets (fun ctx -> ctx.CStructDecls.Count) + let tupleDecl = StructType(sprintf "tuple%i" index, elements) + do! State.modify (fun ctx -> ctx.CStructDecls.Add(type', tupleDecl); ctx) + return tupleDecl } let translateUnion (type': System.Type) = translation { - let name = type'.Name - - let notEmptyCases = - FSharpType.GetUnionCases type' - |> Array.filter (fun case -> case.GetFields().Length <> 0) - - let! fields = - [ - for case in notEmptyCases -> - translation { - let structName = case.Name - let tag = case.Tag - let! fields = - [ - for field in case.GetFields() -> - translate field.PropertyType >>= fun type' -> - State.return' { Name = field.Name; Type = type' } - ] - |> State.collect - - return tag, { Name = structName; Type = StructInplaceType(structName + "Type", fields) } - } - - ] - |> State.collect - - let duType = DiscriminatedUnionType(name, fields) - do! State.modify (fun context -> context.UnionDecls.Add(type', duType); context) - - return duType + let! context = State.get + + if context.CStructDecls.ContainsKey type' then + return context.CStructDecls.[type'] + else + let notEmptyCases = + FSharpType.GetUnionCases type' + |> Array.filter (fun case -> case.GetFields().Length <> 0) + + let! fields = + [ + for case in notEmptyCases -> + translation { + let structName = case.Name + let tag = case.Tag + let! 
fields = + [ + for field in case.GetFields() -> + translate field.PropertyType >>= fun type' -> + State.return' { Name = field.Name; Type = type' } + ] + |> State.collect + + return tag, { Name = structName; Type = StructInplaceType(structName + "Type", fields) } + } + ] + |> State.collect + + let! index = State.gets (fun ctx -> ctx.CStructDecls.Count) + let duType = DiscriminatedUnionType(sprintf "du%i" index, fields) + do! State.modify (fun context -> context.CStructDecls.Add(type', duType); context) + return duType :> StructType<_> } diff --git a/src/Brahma.FSharp.OpenCL.Translator/Utils/TypeReflection.fs b/src/Brahma.FSharp.OpenCL.Translator/Utils/TypeReflection.fs deleted file mode 100644 index f6347223..00000000 --- a/src/Brahma.FSharp.OpenCL.Translator/Utils/TypeReflection.fs +++ /dev/null @@ -1,75 +0,0 @@ -namespace Brahma.FSharp.OpenCL.Translator - -open FSharp.Reflection -open System.Reflection -open System.Collections.Generic -open Microsoft.FSharp.Quotations -open System - -module TypeReflection = - () - // let private hasAttribute<'attr> (tp: Type) = - // tp.GetCustomAttributes(false) - // |> Seq.tryFind (fun attr -> attr.GetType() = typeof<'attr>) - // |> Option.isSome - - // let collectTypes expr typePredicate (nestedTypes: Type -> Type[]) (escapeNames: string[]) = - // let types = HashSet() - - // let rec add (type': Type) = - // if - // typePredicate type' && - // not <| types.Contains type' && - // not <| Array.exists ((=) type'.Name) escapeNames - // then - // nestedTypes type' |> Array.iter add - // types.Add type' |> ignore - - // let rec go (expr: Expr) = - // add expr.Type - - // match expr with - // | ExprShape.ShapeVar _ -> () - // | ExprShape.ShapeLambda (_, body) -> go body - // | ExprShape.ShapeCombination (_, exprs) -> List.iter go exprs - - // go expr - // types |> List.ofSeq - - // let collectUserDefinedStructs expr = - // let isStruct = hasAttribute - // let escapeNames = [||] - - // let nestedTypes (type': Type) = - // seq { - // 
type'.GetProperties() - // |> Array.map (fun prop -> prop.PropertyType) - - // // dont needed i think - // if not <| FSharpType.IsRecord type' then - // type'.GetFields() - // |> Array.map (fun field -> field.FieldType) - // } - // |> Array.concat - - // collectTypes expr isStruct nestedTypes escapeNames - - // let collectTuples expr = - // let isTuple = FSharpType.IsTuple - // let escapeNames = [||] - - // let nestedTypes (type': Type) = FSharpType.GetTupleElements type' - - // collectTypes expr isTuple nestedTypes escapeNames - - // let collectDiscriminatedUnions expr = - // let unionPredicate = FSharpType.IsUnion - // let escapeNames = [||] - - // let nestedTypes (type': Type) = - // FSharpType.GetUnionCases type' - // |> Array.map (fun (case: UnionCaseInfo) -> case.GetFields()) - // |> Array.concat - // |> Array.map (fun (prop: PropertyInfo) -> prop.PropertyType) - - // collectTypes expr unionPredicate nestedTypes escapeNames diff --git a/src/Brahma.FSharp.OpenCL.Translator/Utils/Utils.fs b/src/Brahma.FSharp.OpenCL.Translator/Utils/Utils.fs new file mode 100644 index 00000000..8bd61dcb --- /dev/null +++ b/src/Brahma.FSharp.OpenCL.Translator/Utils/Utils.fs @@ -0,0 +1,19 @@ +namespace Brahma.FSharp.OpenCL.Translator + +open Microsoft.FSharp.Quotations +open System.Collections.Generic +open Brahma.FSharp.OpenCL.AST +open Microsoft.FSharp.Reflection +open System + +module Utils = + let getMethodInfoOfCall (expr: Expr) = + match expr with + | Patterns.Call (_, mInfo, _) -> mInfo + | DerivedPatterns.Lambdas (args, Patterns.Call (_, mInfo, _)) -> mInfo + | _ -> failwithf "Expression is not kind of call, but %O" expr + + let makeGenericMethodCall (types: System.Type list) (expr: Expr) = + (getMethodInfoOfCall expr) + .GetGenericMethodDefinition() + .MakeGenericMethod(Array.ofList types) diff --git a/src/YC.OpenCL.NET/Cl.API.cs b/src/YC.OpenCL.NET/Cl.API.cs index 3a567beb..ecfe7363 100644 --- a/src/YC.OpenCL.NET/Cl.API.cs +++ b/src/YC.OpenCL.NET/Cl.API.cs @@ -29,7 +29,7 
@@ namespace OpenCL.Net public static partial class Cl { - public const string Library = "opencl.dll"; + public const string Library = "opencl"; static Cl() { @@ -39,17 +39,38 @@ static Cl() private static IntPtr ImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath) { IntPtr libHandle = IntPtr.Zero; + + var envOclPath = System.Environment.GetEnvironmentVariable("BRAHMA_OCL_PATH"); + if (libraryName == Library) { - if (OperatingSystem.IsLinux()) + if (NativeLibrary.TryLoad(Library, assembly, searchPath, out libHandle)) + { + return libHandle; + } + else if (NativeLibrary.TryLoad(envOclPath, assembly, searchPath, out libHandle)) + { + return libHandle; + } + + try { - libHandle = NativeLibrary.Load("/usr/lib/x86_64-linux-gnu/libOpenCL.so.1.0.0", assembly, searchPath); + if (OperatingSystem.IsLinux()) + { + libHandle = NativeLibrary.Load("/usr/lib/x86_64-linux-gnu/libOpenCL.so.1.0.0", assembly, searchPath); + } + else if (OperatingSystem.IsMacOS()) + { + libHandle = NativeLibrary.Load("/System/Library/Frameworks/OpenCL.framework/OpenCL", assembly, searchPath); + } } - else if (OperatingSystem.IsMacOS()) + catch (DllNotFoundException e) { - libHandle = NativeLibrary.Load("/System/Library/Frameworks/OpenCL.framework/OpenCL", assembly, searchPath); + Console.WriteLine(e); + Console.WriteLine("Set BRAHMA_OCL_PATH environment variable to OpenCL library path"); } } + return libHandle; } diff --git a/tests/Brahma.FSharp.Tests/AtomicTests.fs b/tests/Brahma.FSharp.Tests/AtomicTests.fs index 030882b1..2970c25b 100644 --- a/tests/Brahma.FSharp.Tests/AtomicTests.fs +++ b/tests/Brahma.FSharp.Tests/AtomicTests.fs @@ -12,6 +12,8 @@ open ExpectoFsCheck open FsCheck open Brahma.FSharp.Tests +// TODO add tests in inc dec on supported types (generate spinlock) + let logger = Log.create "AtomicTests" type NormalizedFloatArray = @@ -77,7 +79,8 @@ let stressTest<'a when 'a : equality and 'a : struct> (f: Expr<'a -> 'a>) size r let gid = range.GlobalID0 if 
gid < size then atomic %f result.[0] |> ignore - barrier () + + barrierLocal () @> let expected = @@ -127,7 +130,7 @@ let foldTest<'a when 'a : equality and 'a : struct> f (isEqual: 'a -> 'a -> bool if gid < arrayLength then atomic %f localResult.[0] array.[gid] |> ignore - barrier () + barrierLocal () if lid = 0 then atomic %f result.[0] localResult.[0] |> ignore @@ -247,9 +250,11 @@ let foldTestCases = testList "Fold tests" [ testCase "Reduce test atomic 'min' on int" <| fun () -> foldTest <@ min @> (=) ptestCase "Reduce test atomic 'min' on int64" <| fun () -> foldTest <@ min @> (=) + testCase "Reduce test atomic 'min' on int16" <| fun () -> foldTest <@ min @> (=) testCase "Reduce test atomic 'max' on int" <| fun () -> foldTest <@ max @> (=) ptestCase "Reduce test atomic 'max' on int64" <| fun () -> foldTest <@ max @> (=) + testCase "Reduce test atomic 'max' on int16" <| fun () -> foldTest <@ max @> (=) testCase "Reduce test atomic '&&&' on int" <| fun () -> foldTest <@ (&&&) @> (=) ptestCase "Reduce test atomic '&&&' on int64" <| fun () -> foldTest <@ (&&&) @> (=) @@ -282,7 +287,7 @@ let perfomanceTest = ptestCase "Perfomance test on 'inc'" <| fun () -> localAcc.[0] <- 0 atomic inc localAcc.[0] |> ignore - barrier () + barrierLocal () if range.LocalID0 = 0 then result.[0] <- localAcc.[0] @@ -298,7 +303,7 @@ let perfomanceTest = ptestCase "Perfomance test on 'inc'" <| fun () -> localAcc.[0] <- 0 atomic %inc localAcc.[0] |> ignore - barrier () + barrierLocal () if range.LocalID0 = 0 then result.[0] <- localAcc.[0] diff --git a/tests/Brahma.FSharp.Tests/Common.fs b/tests/Brahma.FSharp.Tests/Common.fs index a3b32427..c9ac957e 100644 --- a/tests/Brahma.FSharp.Tests/Common.fs +++ b/tests/Brahma.FSharp.Tests/Common.fs @@ -14,6 +14,25 @@ module Common = let platformName = ClPlatform.Any ClContext(platformName, deviceType) + let defaultInArrayLength = 4 + let intInArr = [| 0 .. 
defaultInArrayLength - 1 |] + let float32Arr = Array.init defaultInArrayLength float32 + let default1D = Range1D(defaultInArrayLength, 1) + let default2D = Range2D(defaultInArrayLength, 1) + + let checkResult command (inArr: 'a[]) (expectedArr: 'a[]) = + let actual = + opencl { + use! inBuf = ClArray.toDevice inArr + do! runCommand command <| fun x -> + x default1D inBuf + + return! ClArray.toHost inBuf + } + |> ClTask.runSync context + + Expect.sequenceEqual actual expectedArr "Arrays should be equals" + module CustomDatatypes = [] type WrappedInt = @@ -41,16 +60,13 @@ module Utils = Expect.equal all1 all2 "Files should be equals as strings" let openclCompile (command: Expr<('a -> 'b)>) = - let kernel = context.CreateClKernel command + let kernel = context.CreateClProgram command kernel.Code let openclTranslate (expr: Expr) = - let translator = FSQuotationToOpenCLTranslator() + let translator = FSQuotationToOpenCLTranslator(TranslatorOptions()) let (ast, _) = translator.Translate(expr) print ast let openclTransformQuotation (expr: Expr) = - QuotationTransformers.Transformer.transformQuotation expr [] - -module Generators = - () + QuotationTransformers.Transformer.transformQuotation expr diff --git a/tests/Brahma.FSharp.Tests/CompositeTypesTests.fs b/tests/Brahma.FSharp.Tests/CompositeTypesTests.fs index 80863d08..94440908 100644 --- a/tests/Brahma.FSharp.Tests/CompositeTypesTests.fs +++ b/tests/Brahma.FSharp.Tests/CompositeTypesTests.fs @@ -4,6 +4,7 @@ open Expecto open Brahma.FSharp.OpenCL open FSharp.Quotations open Brahma.FSharp.Tests +open FsCheck // Incomplete pattern matching in record deconstruction #nowarn "667" @@ -29,6 +30,18 @@ type GenericRecord<'a, 'b> = mutable Y: 'b } +[] +type StructOfIntInt64 = + val mutable X: int + val mutable Y: int64 + new(x, y) = { X = x; Y = y } + +[] +type GenericStruct<'a, 'b> = + val mutable X: 'a + val mutable Y: 'b + new(x, y) = { X = x; Y = y } + let check<'a when 'a : struct and 'a : equality> (data: 'a[]) (command: 
int -> Expr ClArray<'a> -> unit>) = let length = data.Length @@ -131,9 +144,94 @@ let recordTestCases = testList "Record tests" [ if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>) ] +let genGenericStruct<'a, 'b> = + gen { + let! x = Arb.generate<'a> + let! y = Arb.generate<'b> + + return GenericStruct(x, y) + } + +type GenericStructGenerator = + static member GenericStruct() = Arb.fromGen genGenericStruct + +let structTests = testList "Struct tests" [ + testCase "Smoke test" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + let b = buf.[0] + buf.[0] <- buf.[1] + buf.[1] <- b + @> + + checkResult command [|StructOfIntInt64(1, 2L); StructOfIntInt64(3, 4L)|] + [|StructOfIntInt64(3, 4L); StructOfIntInt64(1, 2L)|] + + testCase "Struct constructor test" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- StructOfIntInt64(5, 6L) + @> + + checkResult command [|StructOfIntInt64(1, 2L); StructOfIntInt64(3, 4L)|] + [|StructOfIntInt64(5, 6L); StructOfIntInt64(3, 4L)|] + + testCase "Struct prop set" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let mutable y = buf.[0] + y.X <- 5 + buf.[0] <- y + @> + + checkResult command [|StructOfIntInt64(1, 2L)|] [|StructOfIntInt64(5, 2L)|] + + testCase "Struct prop get" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + let mutable y = buf.[0] + y.X <- y.X + 3 + buf.[0] <- y + @> + + checkResult command [|StructOfIntInt64(1, 2L); StructOfIntInt64(3, 4L)|] + [|StructOfIntInt64(4, 2L); StructOfIntInt64(3, 4L)|] + + let inline command length = + <@ + fun (gid: int) (buffer: ClArray>) -> + if gid < length then + let tmp = buffer.[gid] + let x = tmp.X + let y = tmp.Y + let mutable innerStruct = GenericStruct(x, y) + innerStruct.X <- x + innerStruct.Y <- y + buffer.[gid] <- 
GenericStruct(innerStruct.X, innerStruct.Y) + @> + + let config = { FsCheckConfig.defaultConfig with arbitrary = [typeof] } + + testPropertyWithConfig config (message "GenericStruct") <| fun (data: GenericStruct[]) -> + if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>) + + testPropertyWithConfig config (message "GenericStruct<(int * int64), (bool * bool)>") <| fun (data: GenericStruct<(int * int64), (bool * bool)>[]) -> + if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>) + + testPropertyWithConfig config (message "GenericStruct") <| fun (data: GenericStruct[]) -> + if data.Length <> 0 then check data (fun length -> <@ fun (range: Range1D) (buffer: ClArray<_>) -> (%command length) range.GlobalID0 buffer @>) +] + let tests = testList "Tests on composite types" [ tupleTestCases recordTestCases + structTests ] |> testSequenced diff --git a/tests/Brahma.FSharp.Tests/Expected/Barrier.Full.cl b/tests/Brahma.FSharp.Tests/Expected/Barrier.Full.cl new file mode 100644 index 00000000..35901fd1 --- /dev/null +++ b/tests/Brahma.FSharp.Tests/Expected/Barrier.Full.cl @@ -0,0 +1,2 @@ +__kernel void brahmaKernel () +{barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE) ;} diff --git a/tests/Brahma.FSharp.Tests/Expected/Barrier.Global.cl b/tests/Brahma.FSharp.Tests/Expected/Barrier.Global.cl new file mode 100644 index 00000000..3ad75c10 --- /dev/null +++ b/tests/Brahma.FSharp.Tests/Expected/Barrier.Global.cl @@ -0,0 +1,2 @@ +__kernel void brahmaKernel () +{barrier(CLK_GLOBAL_MEM_FENCE) ;} diff --git a/tests/Brahma.FSharp.Tests/Expected/Barrier.Local.cl b/tests/Brahma.FSharp.Tests/Expected/Barrier.Local.cl new file mode 100644 index 00000000..5528d6d9 --- /dev/null +++ b/tests/Brahma.FSharp.Tests/Expected/Barrier.Local.cl @@ -0,0 +1,2 @@ +__kernel void brahmaKernel () +{barrier(CLK_LOCAL_MEM_FENCE) 
;} diff --git a/tests/Brahma.FSharp.Tests/Expected/MAX.Transformation.cl b/tests/Brahma.FSharp.Tests/Expected/MAX.Transformation.cl new file mode 100644 index 00000000..0e0e01ed --- /dev/null +++ b/tests/Brahma.FSharp.Tests/Expected/MAX.Transformation.cl @@ -0,0 +1,5 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel void brahmaKernel (__global double * buf) +{double tempVarY = 1 ; + buf [0] = max (buf [0], tempVarY) ; + buf [0] = max (buf [0], tempVarY) ;} diff --git a/tests/Brahma.FSharp.Tests/FullTests.fs b/tests/Brahma.FSharp.Tests/FullTests.fs index 110513b7..9277575b 100644 --- a/tests/Brahma.FSharp.Tests/FullTests.fs +++ b/tests/Brahma.FSharp.Tests/FullTests.fs @@ -9,32 +9,7 @@ open Expecto.Logging.Message let logger = Log.create "FullTests" -[] -type TestStruct = - val mutable x: int - val mutable y: float - new(x, y) = { x = x; y = y } - -let defaultInArrayLength = 4 -let intInArr = [| 0 .. defaultInArrayLength - 1 |] -let float32Arr = Array.init defaultInArrayLength float32 -let default1D = Range1D(defaultInArrayLength, 1) -let default2D = Range2D(defaultInArrayLength, 1) - -let checkResult command (inArr: 'a[]) (expectedArr: 'a[]) = - let actual = - opencl { - use! inBuf = ClArray.toDevice inArr - do! runCommand command <| fun x -> - x default1D inBuf - - return! ClArray.toHost inBuf - } - |> ClTask.runSync context - - Expect.sequenceEqual actual expectedArr "Arrays should be equals" - -let dataStructuresApiTests = testList "Check correctness of data structures api" [ +let smokeTestsOnPrimitiveTypes = testList "Simple tests on primitive types" [ testCase "Array item set" <| fun _ -> let command = <@ @@ -71,7 +46,7 @@ let dataStructuresApiTests = testList "Check correctness of data structures api" checkResult command [|0y; 1y; 2y; 3y|] [|1y; 1y; 2y; 3y|] - testCase "Array item set. Sequential operations." <| fun _ -> + testCase "Array item set. 
Sequential operations" <| fun _ -> let command = <@ fun (range: Range1D) (buf: ClArray) -> @@ -81,331 +56,181 @@ let dataStructuresApiTests = testList "Check correctness of data structures api" checkResult command intInArr [|2; 4; 2; 3|] - testCase "Getting value of 'int clcell' should be correct" <| fun () -> + testCase "Byte type support with overflow" <| fun _ -> let command = <@ - fun (range: Range1D) (buffer: int clarray) (cell: int clcell) -> - let gid = range.GlobalID0 - buffer.[gid] <- cell.Value + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + buf.[0] <- buf.[0] + 1uy + buf.[1] <- buf.[1] + 1uy + buf.[2] <- buf.[2] + 1uy @> - let value = 10 - let expected = Array.replicate defaultInArrayLength value - - let actual = - opencl { - use! cell = ClCell.toDevice 10 - use! buffer = ClArray.alloc defaultInArrayLength - do! runCommand command <| fun it -> - it - <| default1D - <| buffer - <| cell - - return! ClArray.toHost buffer - } - |> ClTask.runSync context - - "Arrays should be equal" - |> Expect.sequenceEqual actual expected - - // TODO test on getting Value property of non-clcell type - // TODO test on getting Item property on non-clarray type + checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] +] - testCase "Setting value of 'int clcell' should be correct" <| fun () -> - let value = 10 +let typeCastingTests = testList "Type castings tests" [ + testCase "uint64 -> int64" <| fun _ -> let command = <@ - fun (range: Range1D) (cell: int clcell) -> - cell.Value <- value + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- int64 1UL @> - let actual = - opencl { - use! cell = ClCell.toDevice value - do! runCommand command <| fun it -> - it - <| default1D - <| cell - - return! 
ClCell.toHost cell - } - |> ClTask.runSync context - - "Arrays should be equal" - |> Expect.equal actual value + checkResult command [|0L; 1L|] [|1L; 1L|] - testCase "Using 'int clcell' from inner function should work correctly" <| fun () -> - let value = 10 + testCase "int64 -> uint64" <| fun _ -> let command = <@ - fun (range: Range1D) (cell: int clcell) -> - let f () = - let x = cell.Value - cell.Value <- x - - f () + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- uint64 1L @> - let actual = - opencl { - use! cell = ClCell.toDevice value - do! runCommand command <| fun it -> - it - <| default1D - <| cell - - return! ClCell.toHost cell - } - |> ClTask.runSync context - - "Arrays should be equal" - |> Expect.equal actual value + checkResult command [|0UL; 1UL|] [|1UL; 1UL|] - testCase "Using 'int clcell' with native atomic operation should be correct" <| fun () -> - let value = 10 + testCase "byte -> float -> byte" <| fun _ -> let command = <@ - fun (range: Range1D) (cell: int clcell) -> - atomic (+) cell.Value value |> ignore + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + buf.[0] <- byte (float buf.[0]) + buf.[1] <- byte (float buf.[1]) + buf.[2] <- byte (float buf.[2]) @> - let expected = value * default1D.GlobalWorkSize - - let actual = - opencl { - use! cell = ClCell.toDevice 0 - do! runCommand command <| fun it -> - it - <| default1D - <| cell - - return! 
ClCell.toHost cell - } - |> ClTask.runSync context - - "Arrays should be equal" - |> Expect.equal actual expected + checkResult command [|0uy; 255uy; 254uy|] [|0uy; 255uy; 254uy|] - ptestCase "Using 'int clcell' with spinlock atomic operation should be correct" <| fun () -> - let value = 10 + // test fail on Intel platform: + // Actual: [1uy, 255uy, 255uy] + ptestCase "Byte and float 2" <| fun _ -> let command = <@ - fun (range: Range1D) (cell: int clcell) -> - atomic (fun x -> x + value) cell.Value |> ignore + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + buf.[0] <- byte ((float buf.[0]) + 1.0) + buf.[1] <- byte ((float buf.[1]) + 1.0) + buf.[2] <- byte ((float buf.[2]) + 1.0) @> - let expected = value * default1D.GlobalWorkSize - - let actual = - opencl { - use! cell = ClCell.toDevice 0 - do! runCommand command <| fun it -> - it - <| default1D - <| cell - - return! ClCell.toHost cell - } - |> ClTask.runSync context - - "Arrays should be equal" - |> Expect.equal actual expected -] - -let typeCastingTests = - testList "Type castings tests" - [ - testCase "Type casting. Long" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- (int64)1UL - @> - - checkResult command [|0L; 1L|] [|1L; 1L|] - - testCase "Type casting. Ulong" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- 1UL - @> - - checkResult command [|0UL; 1UL; 2UL; 3UL|] [|1UL; 1UL; 2UL; 3UL|] - - testCase "Type casting. 
ULong" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- (uint64)1L - @> - - checkResult command [|0UL; 1UL|] [|1UL; 1UL|] - - testCase "Byte type support" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 - then - buf.[0] <- buf.[0] + 1uy - buf.[1] <- buf.[1] + 1uy - buf.[2] <- buf.[2] + 1uy - @> + checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - - testCase "Byte and float32" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 - then - buf.[0] <- byte (float buf.[0]) - buf.[1] <- byte (float buf.[1]) - buf.[2] <- byte (float buf.[2]) - @> + // test failed on Intel platform: + // Actual : [1uy, 1uy, 1uy] + ptestCase "Byte and float in condition" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + let x = if true then buf.[0] + 1uy else buf.[0] + 1uy + buf.[0] <- x + let y = if true then buf.[1] + 1uy else buf.[1] + 1uy + buf.[1] <- y + let z = if true then buf.[2] + 1uy else buf.[2] + 1uy + buf.[2] <- z + @> - checkResult command [|0uy; 255uy; 254uy|] [|0uy; 255uy; 254uy|] + checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - // test fail on Intel platform: - // Actual: [1uy, 255uy, 255uy] - ptestCase "Byte and float 2" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 + // test failed on Intel platform due to exception + ptestCase "Byte and float in condition 2" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 + then + let x = + if true then - buf.[0] <- byte ((float buf.[0]) + 1.0) - buf.[1] <- byte ((float buf.[1]) + 1.0) - buf.[2] <- byte ((float buf.[2]) + 1.0) - @> - - checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - - // test failed on Intel platform: - // Actual : [1uy, 
1uy, 1uy] - ptestCase "Byte and float in condition" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 + let g = 1uy + buf.[0] + g + else buf.[0] + 1uy + buf.[0] <- x + let y = + if true then - let x = if true then buf.[0] + 1uy else buf.[0] + 1uy - buf.[0] <- x - let y = if true then buf.[1] + 1uy else buf.[1] + 1uy - buf.[1] <- y - let z = if true then buf.[2] + 1uy else buf.[2] + 1uy - buf.[2] <- z - @> - - checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - - // test failed on Intel platform due to exception - ptestCase "Byte and float in condition 2" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 + let g = 1uy + buf.[1] + g + else buf.[1] + 1uy + buf.[1] <- y + let z = + if true then - let x = - if true - then - let g = 1uy - buf.[0] + g - else buf.[0] + 1uy - buf.[0] <- x - let y = - if true - then - let g = 1uy - buf.[1] + g - else buf.[1] + 1uy - buf.[1] <- y - let z = - if true - then - let g = 1uy - buf.[2] + g - else buf.[2] + 1uy - buf.[2] <- z - @> - - checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] - ] + let g = 1uy + buf.[2] + g + else buf.[2] + 1uy + buf.[2] <- z + @> + checkResult command [|0uy; 255uy; 254uy|] [|1uy; 0uy; 255uy|] +] -let bindingTests = - testList "Bindings tests" - [ - testCase "Bindings. Simple." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let x = 1 - buf.[0] <- x - @> +let bindingTests = testList "Bindings tests" [ + testCase "Bindings. Simple" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let x = 1 + buf.[0] <- x + @> - checkResult command intInArr [|1; 1; 2; 3|] + checkResult command intInArr [|1; 1; 2; 3|] - testCase "Bindings. Sequential bindings." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let x = 1 - let y = x + 1 - buf.[0] <- y - @> + testCase "Bindings. 
Sequential bindings" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let x = 1 + let y = x + 1 + buf.[0] <- y + @> - checkResult command intInArr [|2; 1; 2; 3|] + checkResult command intInArr [|2; 1; 2; 3|] - testCase "Bindings. Binding in IF." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if 2 = 0 - then - let x = 1 - buf.[0] <- x - else - let i = 2 - buf.[0] <- i - @> + testCase "Bindings. Binding in IF" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if 2 = 0 then + let x = 1 + buf.[0] <- x + else + let i = 2 + buf.[0] <- i + @> - checkResult command intInArr [|2; 1; 2; 3|] + checkResult command intInArr [|2; 1; 2; 3|] - testCase "Bindings. Binding in FOR." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - for i in 0..3 do - let x = i * i - buf.[i] <- x - @> + testCase "Bindings. Binding in FOR" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + for i in 0..3 do + let x = i * i + buf.[i] <- x + @> - checkResult command intInArr [|0; 1; 4; 9|] + checkResult command intInArr [|0; 1; 4; 9|] - testCase "Bindings. Binding in WHILE." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - while buf.[0] < 5 do - let x = buf.[0] + 1 - buf.[0] <- x * x - @> + testCase "Bindings. 
Binding in WHILE" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + while buf.[0] < 5 do + let x = buf.[0] + 1 + buf.[0] <- x * x + @> - checkResult command intInArr [|25; 1; 2; 3|] - ] + checkResult command intInArr [|25; 1; 2; 3|] +] let operatorsAndMathFunctionsTests = - let testOpGen testCase - (name: string) + let binaryOpTestGen testCase name (binop: Expr<'a -> 'a -> 'a>) (xs: array<'a>) (ys: array<'a>) (expected: array<'a>) = - testCase name <| fun _ -> + + testCase name <| fun () -> let command = <@ fun (range: Range1D) (xs: ClArray<'a>) (ys: ClArray<'a>) (zs: ClArray<'a>) -> @@ -431,824 +256,820 @@ let operatorsAndMathFunctionsTests = Expect.sequenceEqual actual expected ":(" - testList "Operators and math functions tests" - [ - testOpGen testCase "Boolean or 1." <@ (||) @> - [|true; false; false; false|] - [|false; true; true; true|] - [|true; true; true; true|] - - testOpGen testCase "Boolean or 2." <@ (||) @> - [|true; false|] - [|false; true|] - [|true; true|] - - testOpGen testCase "Boolean and 1." <@ (&&) @> - [|true; false; false; false|] - [|true; false; true; true|] - [|true; false; false; false|] - - testOpGen testCase "Binop plus 1." <@ (+) @> - [|1; 2; 3; 4|] - [|5; 6; 7; 8|] - [|6; 8; 10; 12|] - - // Failed: due to precision - ptestCase "Math sin." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let i = range.GlobalID0 - buf.[i] <- System.Math.Sin (float buf.[i]) - @> - - let inA = [|0.0; 1.0; 2.0; 3.0|] - checkResult command inA (inA |> Array.map System.Math.Sin) //[|0.0; 0.841471; 0.9092974; 0.14112|] - ] + let unaryOpTestGen testCase name + (unop: Expr<'a -> 'a>) + (xs: array<'a>) + (expected: array<'a>) = -let pipeTests = - testList "Pipe tests" [ - // Lambda is not supported. - ptestCase "Forward pipe." 
<| fun _ -> + testCase name <| fun () -> let command = <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- (1.25f |> int) + fun (range: Range1D) (xs: ClArray<'a>) (zs: ClArray<'a>) -> + let i = range.GlobalID0 + zs.[i] <- (%unop) xs.[i] @> - checkResult command intInArr [|1; 1; 2; 3|] - // Lambda is not supported. - ptestCase "Backward pipe." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- int <| 1.25f + 2.34f - @> - checkResult command intInArr [|3; 1; 2; 3|] + let range = Range1D <| Array.length expected + let zs = Array.zeroCreate <| Array.length expected - ptestCase "Check simple '|> ignore'" <| fun () -> - let command = - <@ - fun (range: Range1D) (buffer: ClArray) -> - let gid = range.GlobalID0 - atomic inc buffer.[gid] |> ignore - @> + let actual = + opencl { + use! inBufXs = ClArray.toDevice xs + use! outBuf = ClArray.toDevice zs - checkResult command intInArr (intInArr |> Array.map ((+) 1)) -] + do! runCommand command <| fun x -> + x range inBufXs outBuf -let controlFlowTests = - testList "Control flow tests" [ - testCase "Control flow. If Then." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if 0 = 2 then buf.[0] <- 42 - @> + return! ClArray.toHost outBuf + } + |> ClTask.runSync context - checkResult command intInArr [|0; 1; 2; 3|] + Expect.sequenceEqual actual expected ":(" - testCase "Control flow. If Then Else." 
<| fun _ -> + testList "Operators and math functions tests" [ + binaryOpTestGen testCase "Boolean OR" <@ (||) @> + [|true; false; false; true|] + [|false; true; false; true|] + [|true; true; false; true|] + + binaryOpTestGen testCase "Boolean AND" <@ (&&) @> + [|true; false; false; true|] + [|false; true; false; true|] + [|false; false; false; true|] + + binaryOpTestGen testCase "Bitwise OR on int" <@ (|||) @> + [|1; 0; 0; 1|] + [|0; 1; 0; 1|] + [|1; 1; 0; 1|] + + binaryOpTestGen testCase "Bitwise AND on int" <@ (&&&) @> + [|1; 0; 0; 1|] + [|0; 1; 0; 1|] + [|0; 0; 0; 1|] + + binaryOpTestGen testCase "Bitwise XOR on int" <@ (^^^) @> + [|1; 0; 0; 1|] + [|0; 1; 0; 1|] + [|1; 1; 0; 0|] + + binaryOpTestGen testCase "Arithmetic PLUS on int" <@ (+) @> + [|1; 2; 3; 4|] + [|5; 6; 7; 8|] + [|6; 8; 10; 12|] + + unaryOpTestGen testCase "Bitwise NEGATION on int" <@ (~~~) @> + <|| ( + [|1; 10; 99; 0|] + |> fun array -> array, array |> Array.map (fun x -> - x - 1) + ) + + binaryOpTestGen testCase "MAX on float32" <@ max @> + [|1.f; 2.f; 3.f; 4.f|] + [|5.f; 6.f; 7.f; 8.f|] + [|5.f; 6.f; 7.f; 8.f|] + + binaryOpTestGen testCase "MIN on float32" <@ min @> + [|1.f; 2.f; 3.f; 4.f|] + [|5.f; 6.f; 7.f; 8.f|] + [|1.f; 2.f; 3.f; 4.f|] + + ptestCase "MAX on int16 with const" <| fun () -> let command = <@ - fun (range: Range1D) (buf: ClArray) -> - if 0 = 2 then buf.[0] <- 1 else buf.[0] <- 2 + fun (range: Range1D) (buf: int16 clarray) -> + let gid = range.GlobalID0 + buf.[gid] <- max buf.[gid] 1s @> - checkResult command intInArr [|2; 1; 2; 3|] + let inA = [|0s; 1s; 2s; 3s|] + checkResult command inA (Array.map (max 1s) inA) - testCase "Control flow. For Integer Loop." 
<| fun _ -> + // Failed: due to precision + ptestCase "Math sin" <| fun _ -> let command = <@ - fun (range: Range1D) (buf: ClArray) -> - for i in 1..3 do - buf.[i] <- 0 + fun (range: Range1D) (buf: ClArray) -> + let i = range.GlobalID0 + buf.[i] <- System.Math.Sin (float buf.[i]) @> - checkResult command intInArr [|0; 0; 0; 0|] - - testCase "Control flow. WHILE loop simple test." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - while buf.[0] < 5 do - buf.[0] <- buf.[0] + 1 - @> + let inA = [|0.0; 1.0; 2.0; 3.0|] + checkResult command inA (inA |> Array.map System.Math.Sin) //[|0.0; 0.841471; 0.9092974; 0.14112|] + ] - checkResult command intInArr [|5; 1; 2; 3|] +let controlFlowTests = testList "Control flow tests" [ + testCase "Check 'if then' condition" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if 0 = 2 then buf.[0] <- 42 + @> - testCase "Control flow. WHILE in FOR." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - for i in 0..3 do - while buf.[i] < 10 do - buf.[i] <- buf.[i] * buf.[i] + 1 - @> + checkResult command intInArr [|0; 1; 2; 3|] - checkResult command intInArr [|26; 26; 26; 10|] -] + testCase "Check 'if then else' condition" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if 0 = 2 then buf.[0] <- 1 else buf.[0] <- 2 + @> -let kernelArgumentsTests = - testList "Kernel arguments tests" [ - testCase "Kernel arguments. Simple 1D." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let i = range.GlobalID0 - buf.[i] <- i + i - @> + checkResult command intInArr [|2; 1; 2; 3|] - checkResult command intInArr [|0;2;4;6|] + testCase "Check 'for' integer loop" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + for i in 0 .. 3 do + buf.[i] <- i + @> - testCase "Kernel arguments. Simple 1D with copy." 
<| fun _ -> - let command = - <@ - fun (range: Range1D) (inBuf:ClArray) (outBuf:ClArray) -> - let i = range.GlobalID0 - outBuf.[i] <- inBuf.[i] - @> + checkResult command intInArr [|0; 1; 2; 3|] - let expected = [|0; 1; 2; 3|] + testCase "Check 'for' integer loop with step" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + for i in 0 .. 2 .. 6 do + buf.[i / 2] <- i + @> - let actual = - opencl { - use! inBuf = ClArray.toDevice intInArr - use! outBuf = ClArray.toDevice [|0; 0; 0; 0|] - do! runCommand command <| fun x -> - x default1D inBuf outBuf + checkResult command intInArr [|0; 2; 4; 6|] - return! ClArray.toHost inBuf - } - |> ClTask.runSync context + testCase "Check 'for' non-integer loop" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 then + for i in 0u .. 3u do + buf.[int i] <- i + @> - Expect.sequenceEqual actual expected "Arrays should be equals" + checkResult command [|0u; 0u; 0u; 0u|] [|0u; 1u; 2u; 3u|] - testCase "Kernel arguments. Simple 1D float." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let i = range.GlobalID0 - buf.[i] <- buf.[i] * buf.[i] - @> + testCase "Check simple 'while' loop" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + while buf.[0] < 5 do + buf.[0] <- buf.[0] + 1 + @> - checkResult command float32Arr [|0.0f; 1.0f; 4.0f; 9.0f|] + checkResult command intInArr [|5; 1; 2; 3|] - testCase "Kernel arguments. Int as arg." <| fun _ -> - let command = - <@ - fun (range: Range1D) x (buf: ClArray) -> - let i = range.GlobalID0 - buf.[i] <- x + x - @> + testCase "Check 'while' loop inside 'for' integer loop" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + for i in 0 .. 3 do + while buf.[i] < 10 do + buf.[i] <- buf.[i] * buf.[i] + 1 + @> - let expected = [|4; 4; 4; 4|] + checkResult command intInArr [|26; 26; 26; 10|] +] - let actual = - opencl { - use! 
inBuf = ClArray.toDevice intInArr - do! runCommand command <| fun x -> - x default1D 2 inBuf +let kernelArgumentsTests = testList "Kernel arguments tests" [ + testCase "Kernel arguments. Simple 1D" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let i = range.GlobalID0 + buf.[i] <- i + i + @> - return! ClArray.toHost inBuf - } - |> ClTask.runSync context + checkResult command intInArr [|0;2;4;6|] - Expect.sequenceEqual actual expected "Arrays should be equals" + testCase "Kernel arguments. Simple 1D with copy" <| fun _ -> + let command = + <@ + fun (range: Range1D) (inBuf:ClArray) (outBuf:ClArray) -> + let i = range.GlobalID0 + outBuf.[i] <- inBuf.[i] + @> - testCase "Kernel arguments. Sequential commands over single buffer." <| fun _ -> - let command = - <@ - fun (range: Range1D) i x (buf: ClArray) -> - buf.[i] <- x + x - @> + let expected = [|0; 1; 2; 3|] - let expected = [|4; 1; 4; 3|] + let actual = + opencl { + use! inBuf = ClArray.toDevice intInArr + use! outBuf = ClArray.toDevice [|0; 0; 0; 0|] + do! runCommand command <| fun x -> + x default1D inBuf outBuf - let actual = - opencl { - let! ctx = ClTask.ask - let kernel = (ctx.CreateClKernel command).GetNewKernel() + return! ClArray.toHost inBuf + } + |> ClTask.runSync context - let inArr = ctx.CreateClArray(intInArr) + Expect.sequenceEqual actual expected "Arrays should be equals" - ctx.CommandQueue.Post(Msg.MsgSetArguments(fun () -> kernel.ArgumentsSetter default1D 0 2 inArr)) - ctx.CommandQueue.Post(Msg.CreateRunMsg<_,_>(kernel)) + testCase "Kernel arguments. 
Simple 1D float" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let i = range.GlobalID0 + buf.[i] <- buf.[i] * buf.[i] + @> - ctx.CommandQueue.Post(Msg.MsgSetArguments(fun () -> kernel.ArgumentsSetter default1D 2 2 inArr)) - ctx.CommandQueue.Post(Msg.CreateRunMsg<_,_>(kernel)) + checkResult command float32Arr [|0.0f; 1.0f; 4.0f; 9.0f|] - let localOut = Array.zeroCreate intInArr.Length - let res = ctx.CommandQueue.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(inArr, localOut, ch)) - ctx.CommandQueue.Post <| Msg.CreateFreeMsg(inArr) + testCase "Kernel arguments. Int as arg" <| fun _ -> + let command = + <@ + fun (range: Range1D) x (buf: ClArray) -> + let i = range.GlobalID0 + buf.[i] <- x + x + @> - return res - } - |> ClTask.runSync context + let expected = [|4; 4; 4; 4|] - Expect.sequenceEqual actual expected "Arrays should be equals" - - testProperty "Parallel execution of kernel" <| fun _const -> - let n = 4 - let l = 256 - let getAllocator (context:ClContext) = - let kernel = - <@ - fun (r: Range1D) (buffer: ClArray) -> - let i = r.GlobalID0 - buffer.[i] <- _const - @> - let k = context.CreateClKernel kernel - fun (q:MailboxProcessor<_>) -> - let buf = context.CreateClArray(l, allocationMode = AllocationMode.AllocHostPtr) - let executable = k.GetNewKernel() - q.Post(Msg.MsgSetArguments(fun () -> executable.ArgumentsSetter (Range1D(l, l)) buf)) - q.Post(Msg.CreateRunMsg<_,_>(executable)) - buf - - let allocator = getAllocator context - let allocOnGPU (q:MailboxProcessor<_>) allocator = - let b = allocator q - let res = Array.zeroCreate l - q.PostAndReply (fun ch -> Msg.CreateToHostMsg(b, res, ch)) - q.Post (Msg.CreateFreeMsg b) - res - - - let actual = - Array.init n (fun _ -> - let q = context.CommandQueue - q.Error.Add (fun e -> printfn "%A" e) - q) - |> Array.mapi (fun i q -> async {return allocOnGPU q allocator}) - |> Async.Parallel - |> Async.RunSynchronously - - let expected = Array.init n (fun _ -> Array.create l _const) - - 
Expect.sequenceEqual actual expected "Arrays should be equals" - ] + let actual = + opencl { + use! inBuf = ClArray.toDevice intInArr + do! runCommand command <| fun x -> + x default1D 2 inBuf -let quotationInjectionTests = - testList "Quotation injection tests" [ - testCase "Quotations injections. Quotations injections 1." <| fun _ -> - let myF = <@ fun x -> x * x @> + return! ClArray.toHost inBuf + } + |> ClTask.runSync context - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- (%myF) 2 - buf.[1] <- (%myF) 4 - @> + Expect.sequenceEqual actual expected "Arrays should be equals" - checkResult command intInArr [|4;16;2;3|] + testCase "Kernel arguments. Sequential commands over single buffer" <| fun _ -> + let command = + <@ + fun (range: Range1D) i x (buf: ClArray) -> + buf.[i] <- x + x + @> - testCase "Quotations injections. Quotations injections 2." <| fun _ -> - let myF = <@ fun x y -> y - x @> + let expected = [|4; 1; 4; 3|] - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- (%myF) 2 5 - buf.[1] <- (%myF) 4 9 - @> + let actual = + opencl { + let! ctx = ClTask.ask + let kernel = ctx.CreateClProgram(command).GetKernel() - checkResult command intInArr [|3;5;2;3|] - ] + let inArr = ctx.CreateClArray(intInArr) -let localMemTests = - testList "Local memory tests" [ - // TODO: pointers to local data must be local too. - testCase "Local int. 
Work item counting" <| fun _ -> - let command = - <@ - fun (range: Range1D) (output: ClArray) -> - let globalID = range.GlobalID0 - let mutable x = local () + ctx.CommandQueue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc default1D 0 2 inArr)) + ctx.CommandQueue.Post(Msg.CreateRunMsg<_,_>(kernel)) - if globalID = 0 then x <- 0 - barrier () + ctx.CommandQueue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc default1D 2 2 inArr)) + ctx.CommandQueue.Post(Msg.CreateRunMsg<_,_>(kernel)) - atomic (+) x 1 |> ignore - // fetch local value before read, dont work withour barrier - barrier () + let localOut = Array.zeroCreate intInArr.Length + let res = ctx.CommandQueue.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(inArr, localOut, ch)) + ctx.CommandQueue.Post <| Msg.CreateFreeMsg(inArr) - if globalID = 0 then - output.[0] <- x + return res + } + |> ClTask.runSync context + + Expect.sequenceEqual actual expected "Arrays should be equals" + + testProperty "Parallel execution of kernel" <| fun _const -> + let n = 4 + let l = 256 + let getAllocator (context: ClContext) = + let kernel = + <@ + fun (r: Range1D) (buffer: ClArray) -> + let i = r.GlobalID0 + buffer.[i] <- _const @> + let k = context.CreateClProgram kernel + fun (q:MailboxProcessor<_>) -> + let buf = context.CreateClArray(l, allocationMode = AllocationMode.AllocHostPtr) + let executable = k.GetKernel() + q.Post(Msg.MsgSetArguments(fun () -> executable.KernelFunc (Range1D(l, l)) buf)) + q.Post(Msg.CreateRunMsg<_,_>(executable)) + buf + + let allocator = getAllocator context + let allocOnGPU (q:MailboxProcessor<_>) allocator = + let b = allocator q + let res = Array.zeroCreate l + q.PostAndReply (fun ch -> Msg.CreateToHostMsg(b, res, ch)) |> ignore + q.Post (Msg.CreateFreeMsg b) + res - let expected = [|5|] + let actual = + Array.init n (fun _ -> context.WithNewCommandQueue().CommandQueue) + |> Array.map (fun q -> async { return allocOnGPU q allocator }) + |> Async.Parallel + |> Async.RunSynchronously - let 
actual = - opencl { - use! inBuf = ClArray.toDevice [|0|] - do! runCommand command <| fun x -> - x (Range1D(5, 5)) inBuf + let expected = Array.init n (fun _ -> Array.create l _const) - return! ClArray.toHost inBuf - } - |> ClTask.runSync context + Expect.sequenceEqual actual expected "Arrays should be equals" +] - Expect.sequenceEqual actual expected "Arrays should be equals" +let quotationInjectionTests = testList "Quotation injection tests" [ + testCase "Quotations injections. Quotations injections 1" <| fun _ -> + let myF = <@ fun x -> x * x @> - testCase "Local array. Test 1" <| fun _ -> - let localWorkSize = 5 - let globalWorkSize = 15 + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- (%myF) 2 + buf.[1] <- (%myF) 4 + @> - let command = - <@ - fun (range: Range1D) (input: ClArray) (output: ClArray) -> - let localBuf = localArray localWorkSize + checkResult command intInArr [|4;16;2;3|] - localBuf.[range.LocalID0] <- range.LocalID0 - barrier() - output.[range.GlobalID0] <- localBuf.[(range.LocalID0 + 1) % localWorkSize] - @> + testCase "Quotations injections. Quotations injections 2" <| fun _ -> + let myF = <@ fun x y -> y - x @> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- (%myF) 2 5 + buf.[1] <- (%myF) 4 9 + @> - let expected = - [| for x in 1..localWorkSize -> x % localWorkSize |] - |> Array.replicate (globalWorkSize / localWorkSize) - |> Array.concat + checkResult command intInArr [|3;5;2;3|] +] - let actual = - opencl { - use! inBuf = ClArray.toDevice (Array.zeroCreate globalWorkSize) - use! outBuf = ClArray.toDevice (Array.zeroCreate globalWorkSize) - do! runCommand command <| fun x -> - x (Range1D(globalWorkSize, localWorkSize)) inBuf outBuf +let localMemTests = testList "Local memory tests" [ + // TODO: pointers to local data must be local too. + testCase "Local int. 
Work item counting" <| fun _ -> + let command = + <@ + fun (range: Range1D) (output: ClArray) -> + let globalID = range.GlobalID0 + let mutable x = local () - return! ClArray.toHost outBuf - } - |> ClTask.runSync context + if globalID = 0 then x <- 0 + barrierLocal () - Expect.sequenceEqual actual expected "Arrays should be equals" + atomic (+) x 1 |> ignore + // fetch local value before read, dont work without barrier + barrierLocal () - ptestCase "Local array. Test 2" <| fun _ -> - let command = - <@ fun (range: Range1D) (buf: ClArray) -> - let localBuf = localArray 42 - atomic xchg localBuf.[0] 1L |> ignore - buf.[0] <- localBuf.[0] - @> + if globalID = 0 then + output.[0] <- x + @> - checkResult command [|0L; 1L; 2L; 3L|] [|1L; 1L; 2L; 3L|] - ] + let expected = [|5|] -let letTransformationTests = - testList "Let Transformation Tests" [ - testCase "Template Let Transformation Test 0" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f = 3 - buf.[0] <- f - @> + let actual = + opencl { + use! inBuf = ClArray.toDevice [|0|] + do! runCommand command <| fun x -> + x (Range1D(5, 5)) inBuf - checkResult command intInArr [|3; 1; 2; 3|] + return! ClArray.toHost inBuf + } + |> ClTask.runSync context - testCase "Template Let Transformation Test 1" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let x = 4 - let f = - let x = 3 - x - buf.[0] <- x + f - @> - checkResult command intInArr [|7; 1; 2; 3|] + Expect.sequenceEqual actual expected "Arrays should be equals" - testCase "Template Let Transformation Test 1.2" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - let x c b = b + c + 4 + y - x 2 3 - buf.[0] <- f 1 - @> + testCase "Local array. 
Test 1" <| fun _ -> + let localWorkSize = 5 + let globalWorkSize = 15 - checkResult command intInArr [|10; 1; 2; 3|] + let command = + <@ + fun (range: Range1D) (input: ClArray) (output: ClArray) -> + let localBuf = localArray localWorkSize - testCase "Template Let Transformation Test 2" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f = - let x = - let y = 3 - y - x - buf.[0] <- f - @> + localBuf.[range.LocalID0] <- range.LocalID0 + barrierLocal () + output.[range.GlobalID0] <- localBuf.[(range.LocalID0 + 1) % localWorkSize] + @> - checkResult command intInArr [|3; 1; 2; 3|] - testCase "Template Let Transformation Test 3" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f = - let f = 5 - f - buf.[0] <- f - @> + let expected = + [| for x in 1..localWorkSize -> x % localWorkSize |] + |> Array.replicate (globalWorkSize / localWorkSize) + |> Array.concat - checkResult command intInArr [|5; 1; 2; 3|] + let actual = + opencl { + use! inBuf = ClArray.toDevice (Array.zeroCreate globalWorkSize) + use! outBuf = ClArray.toDevice (Array.zeroCreate globalWorkSize) + do! runCommand command <| fun x -> + x (Range1D(globalWorkSize, localWorkSize)) inBuf outBuf - testCase "Template Let Transformation Test 4" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f = - let f = - let f = 5 - f - f - buf.[0] <- f - @> + return! ClArray.toHost outBuf + } + |> ClTask.runSync context - checkResult command intInArr [|5; 1; 2; 3|] + Expect.sequenceEqual actual expected "Arrays should be equals" - testCase "Template Let Transformation Test 5" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f a b = - let x y z = y + z - x a b - buf.[0] <- f 1 7 - @> + ptestCase "Local array. 
Test 2" <| fun _ -> + let command = + <@ fun (range: Range1D) (buf: ClArray) -> + let localBuf = localArray 42 + atomic xchg localBuf.[0] 1L |> ignore + buf.[0] <- localBuf.[0] + @> - checkResult command intInArr [|8; 1; 2; 3|] + checkResult command [|0L; 1L; 2L; 3L|] [|1L; 1L; 2L; 3L|] +] - testCase "Template Let Transformation Test 6" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f x y = - let x = x - x + y - buf.[0] <- f 7 8 - @> +let letTransformationTests = testList "Let Transformation Tests" [ + testCase "Template Let Transformation Test 0" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f = 3 + buf.[0] <- f + @> - checkResult command intInArr [|15; 1; 2; 3|] + checkResult command intInArr [|3; 1; 2; 3|] - testCase "Template Let Transformation Test 7" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - let x y = 6 - y - x y - buf.[0] <- f 7 - @> + testCase "Template Let Transformation Test 1" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let x = 4 + let f = + let x = 3 + x + buf.[0] <- x + f + @> + checkResult command intInArr [|7; 1; 2; 3|] - checkResult command intInArr [|-1; 1; 2; 3|] + testCase "Template Let Transformation Test 1.2" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + let x c b = b + c + 4 + y + x 2 3 + buf.[0] <- f 1 + @> - testCase "Template Let Transformation Test 8" <| fun _ -> - let command = - <@ - fun (range: Range1D) (m: ClArray) -> - let p = m.[0] - let x n = - let l = m.[3] - let g k = k + m.[0] + m.[1] - let r = - let y a = - let x = 5 - n + (g 4) - let z t = m.[2] + a - t - z (a + x + l) - y 6 - r + m.[3] - if range.GlobalID0 = 0 - then m.[0] <- x 7 - @> + checkResult command intInArr [|10; 1; 2; 3|] - checkResult command intInArr [|-1; 1; 2; 3|] + testCase "Template Let Transformation Test 2" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) 
-> + let f = + let x = + let y = 3 + y + x + buf.[0] <- f + @> - testCase "Template Let Transformation Test 9" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let x n = - let r = 8 - let h = r + n - h - buf.[0] <- x 9 - @> + checkResult command intInArr [|3; 1; 2; 3|] - checkResult command intInArr [|17; 1; 2; 3|] + testCase "Template Let Transformation Test 3" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f = + let f = 5 + f + buf.[0] <- f + @> - testCase "Template Let Transformation Test 10" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let p = 9 - let x n b = - let t = 0 - n + b + t - buf.[0] <- x 7 9 - @> + checkResult command intInArr [|5; 1; 2; 3|] - checkResult command intInArr [|16; 1; 2; 3|] + testCase "Template Let Transformation Test 4" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f = + let f = + let f = 5 + f + f + buf.[0] <- f + @> - testCase "Template Let Transformation Test 11" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let p = 1 - let m = - let r (l:int) = l - r 9 - let z (k:int) = k - buf.[0] <- m - @> + checkResult command intInArr [|5; 1; 2; 3|] - checkResult command intInArr [|9; 1; 2; 3|] + testCase "Template Let Transformation Test 5" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f a b = + let x y z = y + z + x a b + buf.[0] <- f 1 7 + @> - testCase "Template Let Transformation Test 12" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f x y = - let y = y - let y = y - let g x m = m + x - g x y - buf.[0] <- f 1 7 - @> + checkResult command intInArr [|8; 1; 2; 3|] - checkResult command intInArr [|8; 1; 2; 3|] + testCase "Template Let Transformation Test 6" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f x y = + let x = x + x + y + buf.[0] <- f 7 8 + @> - testCase "Template Let Transformation 
Test 13" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - let y = y - let y = y - let g (m:int) = m - g y - buf.[0] <- f 7 - @> + checkResult command intInArr [|15; 1; 2; 3|] - checkResult command intInArr [|7; 1; 2; 3|] + testCase "Template Let Transformation Test 7" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + let x y = 6 - y + x y + buf.[0] <- f 7 + @> - testCase "Template Let Transformation Test 14" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - let y = y - let y = y - let g (m:int) = - let g r t = r + y - t - let n o = o - (g y 2) - n 5 - g y - let z y = y - 2 - buf.[0] <- f (z 7) - @> + checkResult command intInArr [|-1; 1; 2; 3|] - checkResult command intInArr [|-3; 1; 2; 3|] + testCase "Template Let Transformation Test 8" <| fun _ -> + let command = + <@ + fun (range: Range1D) (m: ClArray) -> + let p = m.[0] + let x n = + let l = m.[3] + let g k = k + m.[0] + m.[1] + let r = + let y a = + let x = 5 - n + (g 4) + let z t = m.[2] + a - t + z (a + x + l) + y 6 + r + m.[3] + if range.GlobalID0 = 0 + then m.[0] <- x 7 + @> - testCase "Template Let Transformation Test 15" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - let Argi index = - if index = 0 - then buf.[1] - else buf.[2] - Argi y - buf.[0] <- f 0 - @> + checkResult command intInArr [|-1; 1; 2; 3|] - checkResult command intInArr [|1; 1; 2; 3|] + testCase "Template Let Transformation Test 9" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let x n = + let r = 8 + let h = r + n + h + buf.[0] <- x 9 + @> - testCase "Template Let Transformation Test 16" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let f y = - if y = 0 - then - let z (a:int) = a - z 9 - else buf.[2] - buf.[0] <- f 0 - @> + checkResult command intInArr [|17; 1; 2; 3|] - checkResult command intInArr [|9; 1; 2; 3|] + testCase 
"Template Let Transformation Test 10" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let p = 9 + let x n b = + let t = 0 + n + b + t + buf.[0] <- x 7 9 + @> - testCase "Template Let Transformation Test 17" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 - then - let f y = - let g = buf.[1] + 1 - y + g - for i in 0..3 do - buf.[i] <- f i - @> + checkResult command intInArr [|16; 1; 2; 3|] - checkResult command intInArr [|2; 3; 6; 7|] + testCase "Template Let Transformation Test 11" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let p = 1 + let m = + let r (l:int) = l + r 9 + let z (k:int) = k + buf.[0] <- m + @> - testCase "Template Let Transformation Test 18" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - for i in 0..3 do - let f = - let g = buf.[1] + 1 - i + g - if range.GlobalID0 = 0 - then buf.[i] <- f - @> + checkResult command intInArr [|9; 1; 2; 3|] - checkResult command intInArr [|2; 3; 6; 7|] + testCase "Template Let Transformation Test 12" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f x y = + let y = y + let y = y + let g x m = m + x + g x y + buf.[0] <- f 1 7 + @> - testCase "Template Let Transformation Test 19" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 - then - for i in 0..3 do - let f x = - let g = buf.[1] + x - i + g - buf.[i] <- f 1 - @> + checkResult command intInArr [|8; 1; 2; 3|] + + testCase "Template Let Transformation Test 13" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + let y = y + let y = y + let g (m:int) = m + g y + buf.[0] <- f 7 + @> + + checkResult command intInArr [|7; 1; 2; 3|] + + testCase "Template Let Transformation Test 14" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + let y = y + let y = y + let g (m:int) = + let g r t = 
r + y - t + let n o = o - (g y 2) + n 5 + g y + let z y = y - 2 + buf.[0] <- f (z 7) + @> - checkResult command intInArr [|2; 3; 6; 7|] + checkResult command intInArr [|-3; 1; 2; 3|] - // TODO: perform range (1D, 2D, 3D) erasure when range is lifted. - ptestCase "Template Let Transformation Test 20" <| fun _ -> - let command = - <@ - fun (range: Range1D) (m: ClArray) -> - let f x = - range.GlobalID0 + x - m.[0] <- f 2 - @> + testCase "Template Let Transformation Test 15" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + let Argi index = + if index = 0 + then buf.[1] + else buf.[2] + Argi y + buf.[0] <- f 0 + @> - checkResult command intInArr [|2; 3; 6; 7|] - ] + checkResult command intInArr [|1; 1; 2; 3|] -let letQuotationTransformerSystemTests = - testList "Let Transformation Tests Mutable Vars" [ - testCase "Test 0" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let mutable x = 1 - let f y = - x <- y - f 10 - buf.[0] <- x - @> + testCase "Template Let Transformation Test 16" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let f y = + if y = 0 + then + let z (a:int) = a + z 9 + else buf.[2] + buf.[0] <- f 0 + @> - checkResult command intInArr [|10; 1; 2; 3|] + checkResult command intInArr [|9; 1; 2; 3|] - testCase "Test 1" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let mutable x = 1 + testCase "Template Let Transformation Test 17" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 + then let f y = - x <- x + y - f 10 - buf.[0] <- x - @> + let g = buf.[1] + 1 + y + g + for i in 0..3 do + buf.[i] <- f i + @> - checkResult command intInArr [|11; 1; 2; 3|] + checkResult command intInArr [|2; 3; 6; 7|] - testCase "Test 2" <| fun _ -> - let command = - <@ - fun (range: Range1D) (arr: ClArray) -> - let f x = - let g y = y + 1 - g x - arr.[0] <- f 2 - @> + testCase "Template Let Transformation Test 
18" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + for i in 0..3 do + let f = + let g = buf.[1] + 1 + i + g + if range.GlobalID0 = 0 + then buf.[i] <- f + @> - checkResult command intInArr [|3; 1; 2; 3|] + checkResult command intInArr [|2; 3; 6; 7|] - testCase "Test 3" <| fun _ -> - let command = - <@ - fun (range: Range1D) (arr: ClArray)-> - let f x = - let g y = - y + x - g (x + 1) - arr.[0] <- f 2 - @> + testCase "Template Let Transformation Test 19" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + if range.GlobalID0 = 0 + then + for i in 0..3 do + let f x = + let g = buf.[1] + x + i + g + buf.[i] <- f 1 + @> - checkResult command intInArr [|5; 1; 2; 3|] + checkResult command intInArr [|2; 3; 6; 7|] - testCase "Test 4" <| fun _ -> - let command = - <@ - fun (range: Range1D) (arr: ClArray) -> - let gid = range.GlobalID0 - let x = - let mutable y = 0 + // TODO: perform range (1D, 2D, 3D) erasure when range is lifted. + ptestCase "Template Let Transformation Test 20" <| fun _ -> + let command = + <@ + fun (range: Range1D) (m: ClArray) -> + let f x = + range.GlobalID0 + x + m.[0] <- f 2 + @> - let addToY x = - y <- y + x + checkResult command intInArr [|2; 3; 6; 7|] +] - for i in 0..5 do - addToY arr.[gid] - y - arr.[gid] <- x - @> +let letQuotationTransformerSystemTests = testList "Let Transformation Tests Mutable Vars" [ + testCase "Test 0" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let mutable x = 1 + let f y = + x <- y + f 10 + buf.[0] <- x + @> - checkResult command intInArr [|0; 6; 12; 18|] + checkResult command intInArr [|10; 1; 2; 3|] - testCase "Test 5" <| fun _ -> - let command = - <@ - fun (range: Range1D) (arr: ClArray) -> - let gid = range.GlobalID0 + testCase "Test 1" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + let mutable x = 1 + let f y = + x <- x + y + f 10 + buf.[0] <- x + @> + + checkResult command intInArr [|11; 1; 2; 3|] - let 
mutable x = - if 0 > 1 then 2 else 3 + testCase "Test 2" <| fun _ -> + let command = + <@ + fun (range: Range1D) (arr: ClArray) -> + let f x = + let g y = y + 1 + g x + arr.[0] <- f 2 + @> - let mutable y = - for i in 0..4 do - x <- x + 1 - x + 1 + checkResult command intInArr [|3; 1; 2; 3|] - let z = - x + y + testCase "Test 3" <| fun _ -> + let command = + <@ + fun (range: Range1D) (arr: ClArray)-> + let f x = + let g y = + y + x + g (x + 1) + arr.[0] <- f 2 + @> - let f () = - arr.[gid] <- x + y + z - f () - @> + checkResult command intInArr [|5; 1; 2; 3|] - checkResult command intInArr [|34; 34; 34; 34|] - ] + testCase "Test 4" <| fun _ -> + let command = + <@ + fun (range: Range1D) (arr: ClArray) -> + let gid = range.GlobalID0 + let x = + let mutable y = 0 -let structTests = - ptestList "Struct tests" [ - testCase "Simple seq of struct." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 then - let b = buf.[0] - buf.[0] <- buf.[1] - buf.[1] <- b - @> + let addToY x = + y <- y + x - checkResult command [|TestStruct(1, 2.0); TestStruct(3, 4.0)|] [|TestStruct(3, 4.0); TestStruct(1, 2.0)|] + for i in 0..5 do + addToY arr.[gid] + y + arr.[gid] <- x + @> - ptestCase "Simple seq of struct changes." 
<| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - buf.[0] <- TestStruct(5, 6.0) - @> + checkResult command intInArr [|0; 6; 12; 18|] - checkResult command [|TestStruct(1, 2.0); TestStruct(3, 4.0)|] - [|TestStruct(5, 6.0); TestStruct(3, 4.0)|] + testCase "Test 5" <| fun _ -> + let command = + <@ + fun (range: Range1D) (arr: ClArray) -> + let gid = range.GlobalID0 - testCase "Simple seq of struct prop set" <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - let mutable y = buf.[0] - y.x <- 5 - buf.[0] <- y - @> + let mutable x = + if 0 > 1 then 2 else 3 - checkResult command [|TestStruct(1, 2.0)|] [|TestStruct(5, 2.0)|] + let mutable y = + for i in 0..4 do + x <- x + 1 + x + 1 - ptestCase "Simple seq of struct prop get." <| fun _ -> - let command = - <@ - fun (range: Range1D) (buf: ClArray) -> - if range.GlobalID0 = 0 - then - let mutable y = buf.[0] - y.x <- y.x + 3 - buf.[0] <- y - @> + let z = + x + y - checkResult command [|TestStruct(1, 2.0); TestStruct(3, 4.0)|] - [|TestStruct(4, 2.0); TestStruct(3, 4.0)|] + let f () = + arr.[gid] <- x + y + z + f () + @> - testCase "Nested structs 1." <| fun _ -> () - ] + checkResult command intInArr [|34; 34; 34; 34|] +] let commonApiTests = testList "Common Api Tests" [ // TODO is it correct? @@ -1274,6 +1095,165 @@ let commonApiTests = testList "Common Api Tests" [ Expect.throwsT <| fun () -> Utils.openclTranslate command |> ignore <| "Exception should be thrown" + + testCase "Check simple '|> ignore'" <| fun () -> + let command = + <@ + fun (range: Range1D) (buffer: ClArray) -> + let gid = range.GlobalID0 + atomic inc buffer.[gid] |> ignore + @> + + checkResult command intInArr (intInArr |> Array.map ((+) 1)) + + // Lambda is not supported. + ptestCase "Forward pipe" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- (1.25f |> int) + @> + checkResult command intInArr [|1; 1; 2; 3|] + + // Lambda is not supported. 
+ ptestCase "Backward pipe" <| fun _ -> + let command = + <@ + fun (range: Range1D) (buf: ClArray) -> + buf.[0] <- int <| 1.25f + 2.34f + @> + checkResult command intInArr [|3; 1; 2; 3|] + + testCase "Getting value of 'int clcell' should be correct" <| fun () -> + let command = + <@ + fun (range: Range1D) (buffer: int clarray) (cell: int clcell) -> + let gid = range.GlobalID0 + buffer.[gid] <- cell.Value + @> + + let value = 10 + let expected = Array.replicate defaultInArrayLength value + + let actual = + opencl { + use! cell = ClCell.toDevice 10 + use! buffer = ClArray.alloc defaultInArrayLength + do! runCommand command <| fun it -> + it + <| default1D + <| buffer + <| cell + + return! ClArray.toHost buffer + } + |> ClTask.runSync context + + "Arrays should be equal" + |> Expect.sequenceEqual actual expected + + // TODO test on getting Value property of non-clcell type + // TODO test on getting Item property on non-clarray type + + testCase "Setting value of 'int clcell' should be correct" <| fun () -> + let value = 10 + let command = + <@ + fun (range: Range1D) (cell: int clcell) -> + cell.Value <- value + @> + + let actual = + opencl { + use! cell = ClCell.toDevice value + do! runCommand command <| fun it -> + it + <| default1D + <| cell + + return! ClCell.toHost cell + } + |> ClTask.runSync context + + "Arrays should be equal" + |> Expect.equal actual value + + testCase "Using 'int clcell' from inner function should work correctly" <| fun () -> + let value = 10 + let command = + <@ + fun (range: Range1D) (cell: int clcell) -> + let f () = + let x = cell.Value + cell.Value <- x + + f () + @> + + let actual = + opencl { + use! cell = ClCell.toDevice value + do! runCommand command <| fun it -> + it + <| default1D + <| cell + + return! 
ClCell.toHost cell + } + |> ClTask.runSync context + + "Arrays should be equal" + |> Expect.equal actual value + + testCase "Using 'int clcell' with native atomic operation should be correct" <| fun () -> + let value = 10 + let command = + <@ + fun (range: Range1D) (cell: int clcell) -> + atomic (+) cell.Value value |> ignore + @> + + let expected = value * default1D.GlobalWorkSize + + let actual = + opencl { + use! cell = ClCell.toDevice 0 + do! runCommand command <| fun it -> + it + <| default1D + <| cell + + return! ClCell.toHost cell + } + |> ClTask.runSync context + + "Arrays should be equal" + |> Expect.equal actual expected + + ptestCase "Using 'int clcell' with spinlock atomic operation should be correct" <| fun () -> + let value = 10 + let command = + <@ + fun (range: Range1D) (cell: int clcell) -> + atomic (fun x -> x + value) cell.Value |> ignore + @> + + let expected = value * default1D.GlobalWorkSize + + let actual = + opencl { + use! cell = ClCell.toDevice 0 + do! runCommand command <| fun it -> + it + <| default1D + <| cell + + return! 
ClCell.toHost cell + } + |> ClTask.runSync context + + "Arrays should be equal" + |> Expect.equal actual expected ] let booleanTests = testList "Boolean Tests" [ @@ -1416,10 +1396,12 @@ let parallelExecutionTests = testList "Parallel Execution Tests" [ |> Expect.sequenceEqual actual expected ] -type T1 = None1 | Some1 of int +type Option1 = + | None1 + | Some1 of int let simpleDUTests = testList "Simple tests on discriminated unions" [ - ptestCase "Option with F#-native syntax" <| fun () -> + testCase "Option with F#-native syntax" <| fun () -> let rnd = System.Random() let input1 = Array.init 100_000 (fun i -> rnd.Next()) let input2 = Array.init 100_000 (fun i -> rnd.Next()) @@ -1434,7 +1416,7 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ let i = ndRange.GlobalID0 if i < inputArrayLength then let x = if input1.[i] < 0 then None else Some input1.[i] - let y = if input2.[i] < 0 then None else Some input1.[i] + let y = if input2.[i] < 0 then None else Some input2.[i] output.[i] <- match (%op) x y with Some x -> x | None -> 0 @> @@ -1443,11 +1425,14 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ use! input1 = ClArray.toDevice input1 use! input2 = ClArray.toDevice input2 use! output = ClArray.alloc 100_000 - let op = <@ fun x y -> - match x with - Some x -> match y with Some y -> Some (x + y) | None -> Some x - | None -> match y with Some y -> Some y | None -> None @> - do! runCommand (add (op)) <| fun x -> + let op = + <@ fun x y -> + match x with + | Some x -> match y with Some y -> Some (x + y) | None -> Some x + | None -> match y with Some y -> Some y | None -> None + @> + + do! 
runCommand (add op) <| fun x -> x <| Range1D.CreateValid(input1.Length, 256) <| input1 @@ -1458,16 +1443,20 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ } |> ClTask.runSync context - let expected = Array.map2 (fun x y -> if x < 0 - then if y < 0 - then 0 - else y - else x + y) input1 input2 + let expected = + (input1, input2) + ||> Array.map2 + (fun x y -> + if x < 0 then + if y < 0 then 0 else y + else + x + y + ) "Arrays should be equal" |> Expect.sequenceEqual actual expected - ptestCase "Option with simplified syntax" <| fun () -> + testCase "Option with simplified syntax" <| fun () -> let rnd = System.Random() let input1 = Array.init 100_000 (fun i -> rnd.Next()) let input2 = Array.init 100_000 (fun i -> rnd.Next()) @@ -1486,7 +1475,7 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ if input1.[i] >= 0 then x <- Some input1.[i] if input2.[i] >= 0 then y <- Some input2.[i] match (%op) x y with - Some x -> output.[i] <- x + | Some x -> output.[i] <- x | None -> output.[i] <- 0 @> @@ -1495,11 +1484,16 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ use! input1 = ClArray.toDevice input1 use! input2 = ClArray.toDevice input2 use! output = ClArray.alloc 100_000 - let op = <@ fun x y -> - match x with - Some x -> match y with Some y -> Some (x + y) | None -> Some x - | None -> match y with Some y -> Some y | None -> None @> - do! runCommand (add (op)) <| fun x -> + let op = + <@ fun x y -> + match x, y with + | Some x, Some y -> Some (x + y) + | Some x, None -> Some x + | None, Some y -> Some y + | None, None -> None + @> + + do! 
runCommand (add op) <| fun x -> x <| Range1D.CreateValid(input1.Length, 256) <| input1 @@ -1510,21 +1504,25 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ } |> ClTask.runSync context - let expected = Array.map2 (fun x y -> if x < 0 - then if y < 0 - then 0 - else y - else x + y) input1 input2 + let expected = + (input1, input2) + ||> Array.map2 + (fun x y -> + if x < 0 then + if y < 0 then 0 else y + else + x + y + ) "Arrays should be equal" |> Expect.sequenceEqual actual expected - ptestCase "Simple custom non-generic DU" <| fun () -> + testCase "Simple custom non-generic DU" <| fun () -> let rnd = System.Random() let input1 = Array.init 100_000 (fun i -> rnd.Next()) let input2 = Array.init 100_000 (fun i -> rnd.Next()) let inputArrayLength = input1.Length - let add (op:Expr T1 -> T1>) = + let add (op:Expr Option1 -> Option1>) = <@ fun (ndRange: Range1D) (input1: int clarray) @@ -1539,7 +1537,7 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ if input2.[i] >= 0 then y <- Some1 input2.[i] let z = (%op) x y match z with - Some1 x -> output.[i] <- x + | Some1 x -> output.[i] <- x | None1 -> output.[i] <- 0 @> @@ -1548,11 +1546,14 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ use! input1 = ClArray.toDevice input1 use! input2 = ClArray.toDevice input2 use! output = ClArray.alloc 100_000 - let op = <@ fun x y -> - match x with - Some1 x -> match y with Some1 y -> Some1 (x + y) | None1 -> Some1 x - | None1 -> match y with Some1 y -> Some1 y | None1 -> None1 @> - do! runCommand (add (op)) <| fun x -> + let op = + <@ fun x y -> + match x with + | Some1 x -> match y with Some1 y -> Some1 (x + y) | None1 -> Some1 x + | None1 -> match y with Some1 y -> Some1 y | None1 -> None1 + @> + + do! 
runCommand (add op) <| fun x -> x <| Range1D.CreateValid(input1.Length, 256) <| input1 @@ -1563,11 +1564,15 @@ let simpleDUTests = testList "Simple tests on discriminated unions" [ } |> ClTask.runSync context - let expected = Array.map2 (fun x y -> if x < 0 - then if y < 0 - then 0 - else y - else x + y) input1 input2 + let expected = + (input1, input2) + ||> Array.map2 + (fun x y -> + if x < 0 then + if y < 0 then 0 else y + else + x + y + ) "Arrays should be equal" |> Expect.sequenceEqual actual expected @@ -1611,16 +1616,14 @@ let tests = testList "System tests with running kernels" [ letTransformationTests letQuotationTransformerSystemTests - dataStructuresApiTests + smokeTestsOnPrimitiveTypes typeCastingTests bindingTests operatorsAndMathFunctionsTests - pipeTests controlFlowTests kernelArgumentsTests quotationInjectionTests localMemTests - structTests booleanTests parallelExecutionTests simpleDUTests diff --git a/tests/Brahma.FSharp.Tests/QuotationsTransformerTests.fs b/tests/Brahma.FSharp.Tests/QuotationsTransformerTests.fs index ac70f43e..63649e08 100644 --- a/tests/Brahma.FSharp.Tests/QuotationsTransformerTests.fs +++ b/tests/Brahma.FSharp.Tests/QuotationsTransformerTests.fs @@ -238,7 +238,7 @@ let quotationTransformerTest = let expectedKernelExpr, expectedMethods = makeMethods expected testCase name <| fun _ -> - let actualKernelExpr, actualKernelMethods = Transformer.transformQuotation expr [] + let (actualKernelExpr, actualKernelMethods) = transformQuotation expr assertMethodListsEqual actualKernelMethods expectedMethods assertExprEqual actualKernelExpr expectedKernelExpr "kernels not equals" diff --git a/tests/Brahma.FSharp.Tests/TranslatorTests.fs b/tests/Brahma.FSharp.Tests/TranslatorTests.fs index 24c2d202..4d279f36 100644 --- a/tests/Brahma.FSharp.Tests/TranslatorTests.fs +++ b/tests/Brahma.FSharp.Tests/TranslatorTests.fs @@ -60,7 +60,6 @@ let basicBinOpsTests = testList "Basic operations translation tests" [ checkCode command "Binop.Plus.gen" 
"Binop.Plus.cl" - testCase "Binary operations. Math." <| fun _ -> let command = <@ fun (range: Range1D) (buf: int clarray) -> @@ -74,6 +73,16 @@ let basicBinOpsTests = testList "Basic operations translation tests" [ checkCode command "Binary.Operations.Math.gen" "Binary.Operations.Math.cl" + testCase "TempVar from MAX transformation should not affect other variables" <| fun () -> + let command = + <@ + fun (range: Range1D) (buf: float clarray) -> + let tempVarY = 1. + buf.[0] <- max buf.[0] tempVarY + buf.[0] <- max buf.[0] tempVarY + @> + + checkCode command "MAX.Transformation.gen" "MAX.Transformation.cl" ] let controlFlowTests = testList "Control flow translation tests" [ @@ -697,6 +706,20 @@ let printfTests = testList "Translation of printf" [ checkCode command "Printf test 6.gen" "Printf test 6.cl" ] +let barrierTests = testList "Barrier translation tests" [ + testCase "Local barrier translation tests" <| fun () -> + let command = <@ fun (range: Range1D) -> barrierLocal () @> + checkCode command "Barrier.Local.gen" "Barrier.Local.cl" + + testCase "Global barrier translation tests" <| fun () -> + let command = <@ fun (range: Range1D) -> barrierGlobal () @> + checkCode command "Barrier.Global.gen" "Barrier.Global.cl" + + testCase "Full barrier translation tests" <| fun () -> + let command = <@ fun (range: Range1D) -> barrierFull () @> + checkCode command "Barrier.Full.gen" "Barrier.Full.cl" +] + let tests = testList "Tests for translator" [ basicLocalIdTests @@ -710,5 +733,6 @@ let tests = localMemoryTests localMemoryAllocationTests printfTests + barrierTests ] |> testSequenced