Skip to content

Commit

Permalink
Rename cuda_call_callback -> cuda_call_hook
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 19, 2024
1 parent aa06107 commit 4f9d96e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

- `get_free_and_total_mem`.
- Multiple missing `sexp_of` conversions.
- `cuda_call_callback` to help in debugging.
- `cuda_call_hook` to help in debugging.
- `is_success` functions.

## [0.5.0] 2024-09-25
Expand Down
4 changes: 2 additions & 2 deletions cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ let cuda_error_printer = function

let () = Printexc.register_printer cuda_error_printer
let is_success = function CUDA_SUCCESS -> true | _ -> false
let cuda_call_callback : (message:string -> status:result -> unit) option ref = ref None
let cuda_call_hook : (message:string -> status:result -> unit) option ref = ref None

let check message status =
(match !cuda_call_callback with None -> () | Some callback -> callback ~message ~status);
(match !cuda_call_hook with None -> () | Some callback -> callback ~message ~status);
if status <> CUDA_SUCCESS then raise @@ Cuda_error { status; message }

let init ?(flags = 0) () = check "cu_init" @@ Cuda.cu_init flags
Expand Down
2 changes: 1 addition & 1 deletion cudajit.mli
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ exception Cuda_error of { status : result; message : string }

val is_success : result -> bool

val cuda_call_callback : (message:string -> status:result -> unit) option ref
val cuda_call_hook : (message:string -> status:result -> unit) option ref
(** The function called after every {!Cuda_ffi.Bindings.Functions} call. [message] is the snake-case
variant of the corresponding CUDA function name. *)

Expand Down
5 changes: 3 additions & 2 deletions test/saxpy.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ let%expect_test "SAXPY" =
let num_blocks = 32 in
let num_threads = 128 in
let module Cu = Cudajit in
Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Printf.printf "%s\n" message);
let prog =
Cu.Nvrtc.compile_to_ptx ~cu_src:kernel ~name:"saxpy" ~options:[ "--use_fast_math" ]
~with_debug:true
in
Cu.cuda_call_callback := Some (fun ~message ~status:_ -> Printf.printf "%s\n" message);
Cu.init ();
if Cu.Device.get_count () > 0 then (
let device = Cu.Device.get ~ordinal:0 in
Expand Down Expand Up @@ -1449,4 +1449,5 @@ let%expect_test "SAXPY" =
5.1 * 4087.0 + 8174.0 = 29017.70; 5.1 * 4088.0 + 8176.0 = 29024.80; 5.1 * 4089.0 + 8178.0 = 29031.90;
5.1 * 4090.0 + 8180.0 = 29039.00; 5.1 * 4091.0 + 8182.0 = 29046.10; 5.1 * 4092.0 + 8184.0 = 29053.20;
5.1 * 4093.0 + 8186.0 = 29060.30; 5.1 * 4094.0 + 8188.0 = 29067.40; 5.1 * 4095.0 + 8190.0 = 29074.50;
|}])
|}]);
Cu.cuda_call_hook := None

0 comments on commit 4f9d96e

Please sign in to comment.