diff --git a/lib/multi_channel.ml b/lib/multi_channel.ml index 7bbe3e1..95725e2 100644 --- a/lib/multi_channel.ml +++ b/lib/multi_channel.ml @@ -33,7 +33,6 @@ type dls_state = { } type 'a t = { - mask: int; channels: 'a Ws_deque.t array; waiters: (waiting_status ref * mutex_condvar ) Chan.t; next_domain_id: int Atomic.t; @@ -54,11 +53,7 @@ let rec log2 n = if n <= 1 then 0 else 1 + (log2 (n asr 1)) let make ?(recv_block_spins = 2048) n = - let sz = Int.shift_left 1 ((log2 (n-1))+1) in - assert ((sz >= n) && (sz > 0)); - assert (Int.logand sz (sz-1) == 0); - { mask = sz - 1; - channels = Array.init sz (fun _ -> Ws_deque.create ()); + { channels = Array.init n (fun _ -> Ws_deque.create ()); waiters = Chan.make_unbounded (); next_domain_id = Atomic.make 0; recv_block_spins; @@ -71,9 +66,10 @@ let register_domain mchan = id let init_domain_state mchan dls_state = - let id = (register_domain mchan) in + let id = register_domain mchan in + let len = Array.length mchan.channels in dls_state.id <- id; - dls_state.steal_offsets <- Array.init ((Array.length mchan.channels)-1) (fun i -> i+1); + dls_state.steal_offsets <- Array.init (len - 1) (fun i -> (id + i + 1) mod len); dls_state [@@inline never] @@ -125,7 +121,7 @@ let rec recv_poll_loop mchan dls cur_offset = else begin let idx = cur_offset + (Random.State.int dls.rng_state k) in let t = Array.unsafe_get offsets idx in - let channel = Array.unsafe_get mchan.channels (Int.logand (dls.id + t) mchan.mask) in + let channel = Array.unsafe_get mchan.channels t in try Ws_deque.steal channel with