diff --git a/base/task.jl b/base/task.jl index 1985fc03b9611..5f75b97249127 100644 --- a/base/task.jl +++ b/base/task.jl @@ -409,8 +409,8 @@ end function enq_work(t::Task) (t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable") - if t.sticky - tid = Threads.threadid(t) + tid = Threads.threadid(t) + if t.sticky || tid != 0 if tid == 0 tid = Threads.threadid() end diff --git a/src/julia_threads.h b/src/julia_threads.h index 1da831bafe4de..bc6ca7a395408 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -155,7 +155,8 @@ struct _jl_tls_states_t { volatile int8_t in_finalizer; int8_t disable_gc; volatile sig_atomic_t defer_signal; - struct _jl_task_t *volatile current_task; + struct _jl_task_t *current_task; + struct _jl_task_t *previous_task; struct _jl_task_t *root_task; void *stackbase; size_t stacksize; diff --git a/src/partr.c b/src/partr.c index 90dcc08436de6..c3bc397652bda 100644 --- a/src/partr.c +++ b/src/partr.c @@ -140,7 +140,7 @@ static inline jl_task_t *multiq_deletemin(void) uint64_t rn1 = 0, rn2; int16_t i, prio1, prio2; jl_task_t *task; - + retry: for (i = 0; i < heap_p; ++i) { rn1 = cong(heap_p, cong_unbias, &ptls->rngseed); rn2 = cong(heap_p, cong_unbias, &ptls->rngseed); @@ -162,6 +162,12 @@ static inline jl_task_t *multiq_deletemin(void) return NULL; task = heaps[rn1].tasks[0]; + if (jl_atomic_load_acquire(&task->tid) != ptls->tid) { + if (jl_atomic_compare_exchange(&task->tid, -1, ptls->tid) != -1) { + jl_mutex_unlock_nogc(&heaps[rn1].lock); + goto retry; + } + } heaps[rn1].tasks[0] = heaps[rn1].tasks[--heaps[rn1].ntasks]; heaps[rn1].tasks[heaps[rn1].ntasks] = NULL; prio1 = INT16_MAX; @@ -244,8 +250,13 @@ JL_DLLEXPORT void jl_enqueue_task(jl_task_t *task) static jl_task_t *get_next_task(jl_value_t *getsticky) { jl_task_t *task = (jl_task_t*)jl_apply(&getsticky, 1); - if (jl_typeis(task, jl_task_type)) + if (jl_typeis(task, jl_task_type)) { + int self = jl_get_ptls_states()->tid; + if (jl_atomic_load_acquire(&task->tid) != self) { + jl_atomic_compare_exchange(&task->tid, -1, self); + } return task; + } return multiq_deletemin(); } diff --git a/src/task.c b/src/task.c index ba3a0617a808f..fb31438d646a7 100644 --- a/src/task.c +++ b/src/task.c @@ -117,9 +117,8 @@ static void NOINLINE save_stack(jl_ptls_t ptls, jl_task_t *lastt, jl_task_t **pt jl_gc_wb_back(lastt); } -static void NOINLINE JL_NORETURN restore_stack(jl_ptls_t ptls, char *p) +static void NOINLINE JL_NORETURN restore_stack(jl_task_t *t, jl_ptls_t ptls, char *p) { - jl_task_t *t = ptls->current_task; size_t nb = t->copy_stack; char *_x = (char*)ptls->stackbase - nb; if (!p) { @@ -128,16 +127,15 @@ static void NOINLINE JL_NORETURN restore_stack(jl_ptls_t ptls, char *p) if ((char*)&_x > _x) { p = (char*)alloca((char*)&_x - _x); } - restore_stack(ptls, p); // pass p to ensure the compiler can't tailcall this or avoid the alloca + restore_stack(t, ptls, p); // pass p to ensure the compiler can't tailcall this or avoid the alloca } assert(t->stkbuf != NULL); memcpy_a16((uint64_t*)_x, (uint64_t*)t->stkbuf, nb); // destroys all but the current stackframe jl_set_fiber(&t->ctx); abort(); // unreachable } -static void restore_stack2(jl_ptls_t ptls, jl_task_t *lastt) +static void restore_stack2(jl_task_t *t, jl_ptls_t ptls, jl_task_t *lastt) { - jl_task_t *t = ptls->current_task; size_t nb = t->copy_stack; char *_x = (char*)ptls->stackbase - nb; assert(t->stkbuf != NULL); @@ -230,11 +228,6 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) jl_task_t *t = *pt; assert(t != ptls->current_task); jl_task_t *lastt = ptls->current_task; -#ifdef ENABLE_TIMINGS - jl_timing_block_t *blk = lastt->timing_stack; - if (blk) - jl_timing_block_stop(blk); -#endif #ifdef JULIA_ENABLE_THREADING // If the current task is not holding any locks, free the locks list // so that it can be GC'd without leaking memory @@ -279,11 +272,6 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) save_stack(ptls, lastt, pt); // allocates (gc-safepoint, and can also fail) if (jl_setjmp(lastt->ctx.uc_mcontext, 0)) { // TODO: mutex unlock the thread we just switched from -#ifdef ENABLE_TIMINGS - assert(blk == ptls->current_task->timing_stack); - if (blk) - jl_timing_block_start(blk); -#endif return; } } @@ -298,11 +286,8 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) ptls->pgcstack = t->gcstack; ptls->world_age = t->world_age; t->gcstack = NULL; + ptls->previous_task = lastt; ptls->current_task = t; - if (!lastt->sticky) - // release lastt to run on any tid - lastt->tid = -1; - t->tid = ptls->tid; jl_ucontext_t *lastt_ctx = (killed ? NULL : &lastt->ctx); #ifdef COPY_STACKS @@ -317,11 +302,11 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) #ifdef COPY_STACKS if (t->copy_stack) { if (lastt_ctx) - restore_stack2(ptls, lastt); + restore_stack2(t, ptls, lastt); else if (lastt->copy_stack) - restore_stack(ptls, NULL); // (doesn't return) + restore_stack(t, ptls, NULL); // (doesn't return) else - restore_stack(ptls, (char*)1); // (doesn't return) + restore_stack(t, ptls, (char*)1); // (doesn't return) } else #endif @@ -336,34 +321,63 @@ static void ctx_switch(jl_ptls_t ptls, jl_task_t **pt) else jl_start_fiber(lastt_ctx, &t->ctx); } - // TODO: mutex unlock the thread we just switched from -#ifdef ENABLE_TIMINGS - assert(blk == ptls->current_task->timing_stack); - if (blk) - jl_timing_block_start(blk); -#endif +} + +static jl_ptls_t NOINLINE refetch_ptls(void) +{ + return jl_get_ptls_states(); } JL_DLLEXPORT void jl_switchto(jl_task_t **pt) { jl_ptls_t ptls = jl_get_ptls_states(); jl_task_t *t = *pt; - if (t == ptls->current_task) { + jl_task_t *ct = ptls->current_task; + if (t == ct) { return; } if (t->state == done_sym || t->state == failed_sym || (t->started && t->stkbuf == NULL)) { - ptls->current_task->exception = t->exception; - ptls->current_task->result = t->result; + ct->exception = t->exception; + ct->result = t->result; return; } if (ptls->in_finalizer) jl_error("task switch not allowed from inside gc finalizer"); if (ptls->in_pure_callback) jl_error("task switch not allowed from inside staged nor pure functions"); + if (t->sticky && jl_atomic_load_acquire(&t->tid) == -1) { + // manually yielding to a task + if (jl_atomic_compare_exchange(&t->tid, -1, ptls->tid) != -1) + jl_error("cannot switch to task running on another thread"); + } + else if (t->tid != ptls->tid) { + jl_error("cannot switch to task running on another thread"); + } sig_atomic_t defer_signal = ptls->defer_signal; int8_t gc_state = jl_gc_unsafe_enter(ptls); + +#ifdef ENABLE_TIMINGS + jl_timing_block_t *blk = ct->timing_stack; + if (blk) + jl_timing_block_stop(blk); +#endif + ctx_switch(ptls, pt); + + ptls = refetch_ptls(); + t = ptls->previous_task; + assert(t->tid == ptls->tid); + if (!t->sticky && !t->copy_stack) + t->tid = -1; + ct = ptls->current_task; + +#ifdef ENABLE_TIMINGS + assert(blk == ct->timing_stack); + if (blk) + jl_timing_block_start(blk); +#endif + jl_gc_unsafe_leave(ptls, gc_state); sig_atomic_t other_defer_signal = ptls->defer_signal; ptls->defer_signal = defer_signal; @@ -506,7 +520,7 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion t->stkbuf = NULL; t->started = 0; t->prio = -1; - t->tid = 0; + t->tid = -1; #ifdef ENABLE_TIMINGS t->timing_stack = NULL; #endif @@ -545,6 +559,11 @@ static void NOINLINE JL_NORETURN start_task(void) jl_ptls_t ptls = jl_get_ptls_states(); jl_task_t *t = ptls->current_task; jl_value_t *res; + + jl_task_t *pt = ptls->previous_task; + if (!pt->sticky && !pt->copy_stack) + pt->tid = -1; + t->started = 1; if (t->exception != jl_nothing) { record_backtrace(ptls);