diff --git a/CHANGES.md b/CHANGES.md index 25e71e0..9ac4bdf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,7 @@ - `get_free_and_total_mem`. - Multiple missing `sexp_of` conversions. -- `cuda_call_callback` to help in debugging. +- `cuda_call_hook` to help in debugging. - `is_success` functions. ## [0.5.0] 2024-09-25 diff --git a/cudajit.ml b/cudajit.ml index 7aae13f..79283da 100644 --- a/cudajit.ml +++ b/cudajit.ml @@ -96,10 +96,10 @@ let cuda_error_printer = function let () = Printexc.register_printer cuda_error_printer let is_success = function CUDA_SUCCESS -> true | _ -> false -let cuda_call_callback : (message:string -> status:result -> unit) option ref = ref None +let cuda_call_hook : (message:string -> status:result -> unit) option ref = ref None let check message status = - (match !cuda_call_callback with None -> () | Some callback -> callback ~message ~status); + (match !cuda_call_hook with None -> () | Some callback -> callback ~message ~status); if status <> CUDA_SUCCESS then raise @@ Cuda_error { status; message } let init ?(flags = 0) () = check "cu_init" @@ Cuda.cu_init flags diff --git a/cudajit.mli b/cudajit.mli index 052568c..ed45a73 100644 --- a/cudajit.mli +++ b/cudajit.mli @@ -56,7 +56,7 @@ exception Cuda_error of { status : result; message : string } val is_success : result -> bool -val cuda_call_callback : (message:string -> status:result -> unit) option ref +val cuda_call_hook : (message:string -> status:result -> unit) option ref (** The function called after every {!Cuda_ffi.Bindings.Functions} call. [message] is the snake-case variant of the corresponding CUDA function name. *) diff --git a/test/saxpy.ml b/test/saxpy.ml index c0c1bc0..9d5b316 100644 --- a/test/saxpy.ml +++ b/test/saxpy.ml @@ -16,11 +16,11 @@ let%expect_test "SAXPY" = let num_blocks = 32 in let num_threads = 128 in let module Cu = Cudajit in + Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Printf.printf "%s\n" message); let prog = Cu.Nvrtc.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ] ~with_debug:true in - Cu.cuda_call_callback := Some (fun ~message ~status:_ -> Printf.printf "%s\n" message); Cu.init (); if Cu.Device.get_count () > 0 then ( let device = Cu.Device.get ~ordinal:0 in @@ -1449,4 +1449,5 @@ let%expect_test "SAXPY" = 5.1 * 4087.0 + 8174.0 = 29017.70; 5.1 * 4088.0 + 8176.0 = 29024.80; 5.1 * 4089.0 + 8178.0 = 29031.90; 5.1 * 4090.0 + 8180.0 = 29039.00; 5.1 * 4091.0 + 8182.0 = 29046.10; 5.1 * 4092.0 + 8184.0 = 29053.20; 5.1 * 4093.0 + 8186.0 = 29060.30; 5.1 * 4094.0 + 8188.0 = 29067.40; 5.1 * 4095.0 + 8190.0 = 29074.50; - |}]) + |}]); + Cu.cuda_call_hook := None