Skip to content

Commit

Permalink
Fetch thread-local information (ptls) through the current task (Julia…
Browse files Browse the repository at this point in the history
…Lang#40715)

Enables task-thread migration!

Co-authored-by: Takafumi Arakaki <aka.tkf@gmail.com>
  • Loading branch information
2 people authored and Amit Shirodkar committed Jun 9, 2021
1 parent fce0bf9 commit b1eca0f
Show file tree
Hide file tree
Showing 73 changed files with 1,387 additions and 1,157 deletions.
4 changes: 2 additions & 2 deletions base/gcutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ end
Immediately run finalizers registered for object `x`.
"""
finalize(@nospecialize(o)) = ccall(:jl_finalize_th, Cvoid, (Ptr{Cvoid}, Any,),
Core.getptls(), o)
finalize(@nospecialize(o)) = ccall(:jl_finalize_th, Cvoid, (Any, Any,),
current_task(), o)

"""
Base.GC
Expand Down
11 changes: 7 additions & 4 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,19 +619,22 @@ function enq_work(t::Task)
# 1. The Task's stack is currently being used by the scheduler for a certain thread.
# 2. There is only 1 thread.
# 3. The multiq is full (can be fixed by making it growable).
if t.sticky || tid != 0 || Threads.nthreads() == 1
if t.sticky || Threads.nthreads() == 1
if tid == 0
tid = Threads.threadid()
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
end
push!(Workqueues[tid], t)
else
tid = 0
if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
# if multiq is full, give to a random thread (TODO fix)
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
if tid == 0
tid = mod(time_ns() % Int, Threads.nthreads()) + 1
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
end
push!(Workqueues[tid], t)
else
tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
Expand Down
4 changes: 2 additions & 2 deletions cli/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ include $(JULIAHOME)/Make.inc
include $(JULIAHOME)/deps/llvm-ver.make


HEADERS := $(addprefix $(SRCDIR)/,jl_exports.h loader.h) $(addprefix $(JULIAHOME)/src/,support/platform.h support/dirpath.h jl_exported_data.inc jl_exported_funcs.inc)
HEADERS := $(addprefix $(SRCDIR)/,jl_exports.h loader.h) $(addprefix $(JULIAHOME)/src/,julia_fasttls.h support/platform.h support/dirpath.h jl_exported_data.inc jl_exported_funcs.inc)

LOADER_CFLAGS = $(JCFLAGS) -I$(BUILDROOT)/src -I$(JULIAHOME)/src -I$(JULIAHOME)/src/support -I$(build_includedir) -ffreestanding
LOADER_LDFLAGS = $(JLDFLAGS) -ffreestanding -L$(build_shlibdir) -L$(build_libdir)
Expand Down Expand Up @@ -116,7 +116,7 @@ endif
$(build_shlibdir)/libjulia-debug.$(JL_MAJOR_MINOR_SHLIB_EXT): $(LIB_DOBJS) $(SRCDIR)/list_strip_symbols.h | $(build_shlibdir) $(build_libdir)
@$(call PRINT_LINK, $(CC) $(call IMPLIB_FLAGS,$@.tmp) $(LOADER_CFLAGS) -DLIBRARY_EXPORTS -shared $(DEBUGFLAGS) $(LIB_DOBJS) -o $@ \
$(JLIBLDFLAGS) $(LOADER_LDFLAGS) $(RPATH_LIB) $(call SONAME_FLAGS,libjulia-debug.$(JL_MAJOR_SHLIB_EXT)))
@$(INSTALL_NAME_CMD)libjulia-debug.$(SHLIB_EXT) $@.tmp
@$(INSTALL_NAME_CMD)libjulia-debug.$(SHLIB_EXT) $@
ifeq ($(OS), WINNT)
@$(call PRINT_ANALYZE, $(OBJCOPY) $(build_libdir)/$(notdir $@).tmp.a $(STRIP_EXPORTED_FUNCS) $(build_libdir)/$(notdir $@).a && rm $(build_libdir)/$(notdir $@).tmp.a)
endif
Expand Down
16 changes: 1 addition & 15 deletions cli/loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/* Bring in definitions for `_OS_X_`, `PATH_MAX` and `PATHSEPSTRING`, `jl_ptls_t`, etc... */
#include "../src/support/platform.h"
#include "../src/support/dirpath.h"
#include "../src/julia_fasttls.h"

#ifdef _OS_WINDOWS_
/* We need to reimplement a bunch of standard library stuff on windows,
Expand Down Expand Up @@ -43,15 +44,6 @@
#include <dlfcn.h>
#endif

// Borrow definitions from `julia.h`
#if defined(__GNUC__)
# define JL_CONST_FUNC __attribute__((const))
#elif defined(_COMPILER_MICROSOFT_)
# define JL_CONST_FUNC __declspec(noalias)
#else
# define JL_CONST_FUNC
#endif

// Borrow definition from `support/dtypes.h`
#ifdef _OS_WINDOWS_
# ifdef LIBRARY_EXPORTS
Expand All @@ -68,12 +60,6 @@
# endif
#define JL_HIDDEN __attribute__ ((visibility("hidden")))
#endif
#ifdef JL_DEBUG_BUILD
#define JL_NAKED __attribute__ ((naked,no_stack_protector))
#else
#define JL_NAKED __attribute__ ((naked))
#endif

/*
* DEP_LIBS is our list of dependent libraries that must be loaded before `libjulia`.
* Note that order matters, as each entry will be opened in-order. We define here a
Expand Down
15 changes: 3 additions & 12 deletions cli/loader_exe.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,14 @@ extern "C" {
#include "loader_win_utils.c"
#endif

/* Define ptls getter, as this cannot be defined within a shared library. */
#if !defined(_OS_WINDOWS_) && !defined(_OS_DARWIN_)
JL_DLLEXPORT JL_CONST_FUNC void * jl_get_ptls_states_static(void)
{
/* Because we can't #include <julia.h> in this file, we define a TLS state object with
* hopefully enough room; at last check, the `jl_tls_states_t` struct was <16KB. */
static __attribute__((tls_model("local-exec"))) __thread char tls_states[32768];
return &tls_states;
}
#endif
JULIA_DEFINE_FAST_TLS

#ifdef _OS_WINDOWS_
int mainCRTStartup(void)
{
int argc;
LPWSTR * wargv = CommandLineToArgv(GetCommandLine(), &argc);
char ** argv = (char **)malloc(sizeof(char *)*(argc+ 1));
char ** argv = (char **)malloc(sizeof(char*) * (argc + 1));
setup_stdio();
#else
int main(int argc, char * argv[])
Expand All @@ -36,7 +27,7 @@ int main(int argc, char * argv[])

// Convert Windows wchar_t values to UTF8
#ifdef _OS_WINDOWS_
for (int i=0; i<argc; i++) {
for (int i = 0; i < argc; i++) {
size_t max_arg_len = 4*wcslen(wargv[i]);
argv[i] = (char *)malloc(max_arg_len);
if (!wchar_to_utf8(wargv[i], argv[i], max_arg_len)) {
Expand Down
15 changes: 8 additions & 7 deletions cli/loader_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,18 @@ JL_DLLEXPORT int jl_load_repl(int argc, char * argv[]) {
}
// Next, if we're on Linux/FreeBSD, set up fast TLS.
#if !defined(_OS_WINDOWS_) && !defined(_OS_DARWIN_)
void (*jl_set_ptls_states_getter)(void *) = lookup_symbol(libjulia_internal, "jl_set_ptls_states_getter");
if (jl_set_ptls_states_getter == NULL) {
jl_loader_print_stderr("ERROR: Cannot find jl_set_ptls_states_getter() function within libjulia-internal!\n");
void (*jl_pgcstack_setkey)(void*, void*(*)(void)) = lookup_symbol(libjulia_internal, "jl_pgcstack_setkey");
if (jl_pgcstack_setkey == NULL) {
jl_loader_print_stderr("ERROR: Cannot find jl_pgcstack_setkey() function within libjulia-internal!\n");
exit(1);
}
void * (*fptr)(void) = lookup_symbol(RTLD_DEFAULT, "jl_get_ptls_states_static");
if (fptr == NULL) {
jl_loader_print_stderr("ERROR: Cannot find jl_get_ptls_states_static(), must define this symbol within calling executable!\n");
void *fptr = lookup_symbol(RTLD_DEFAULT, "jl_get_pgcstack_static");
void *(*key)(void) = lookup_symbol(RTLD_DEFAULT, "jl_pgcstack_addr_static");
if (fptr == NULL || key == NULL) {
jl_loader_print_stderr("ERROR: Cannot find jl_get_pgcstack_static(), must define this symbol within calling executable!\n");
exit(1);
}
jl_set_ptls_states_getter((void *)fptr);
jl_pgcstack_setkey(fptr, key);
#endif

// Load the repl entrypoint symbol and jump into it!
Expand Down
4 changes: 2 additions & 2 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ SRCS += $(RUNTIME_SRCS)

# headers are used for dependency tracking, while public headers will be part of the dist
UV_HEADERS :=
HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h tls.h locks.h atomics.h julia_internal.h options.h timing.h)
PUBLIC_HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h tls.h locks.h atomics.h julia_gcext.h)
HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h julia_fasttls.h locks.h atomics.h julia_internal.h options.h timing.h)
PUBLIC_HEADERS := $(BUILDDIR)/julia_version.h $(wildcard $(SRCDIR)/support/*.h) $(addprefix $(SRCDIR)/,julia.h julia_assert.h julia_threads.h julia_fasttls.h locks.h atomics.h julia_gcext.h)
ifeq ($(USE_SYSTEM_LIBUV),0)
UV_HEADERS += uv.h
UV_HEADERS += uv/*.h
Expand Down
43 changes: 23 additions & 20 deletions src/array.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ size_t jl_arr_xtralloc_limit = 0;
static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
int8_t isunboxed, int8_t hasptr, int8_t isunion, int8_t zeroinit, int elsz)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
size_t i, tot, nel=1;
void *data;
jl_array_t *a;
Expand Down Expand Up @@ -119,7 +119,7 @@ static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
size_t doffs = tsz;
tsz += tot;
tsz = JL_ARRAY_ALIGN(tsz, JL_SMALL_BYTE_ALIGNMENT); // align whole object
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
// No allocation or safepoint allowed after this
a->flags.how = 0;
data = (char*)a + doffs;
Expand All @@ -129,10 +129,10 @@ static jl_array_t *_new_array_(jl_value_t *atype, uint32_t ndims, size_t *dims,
data = jl_gc_managed_malloc(tot);
// Allocate the Array **after** allocating the data
// to make sure the array is still young
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
// No allocation or safepoint allowed after this
a->flags.how = 2;
jl_gc_track_malloced_array(ptls, a);
jl_gc_track_malloced_array(ct->ptls, a);
}
a->flags.pooled = tsz <= GC_MAX_SZCLASS;

Expand Down Expand Up @@ -213,7 +213,7 @@ static inline int is_ntuple_long(jl_value_t *v)
JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,
jl_value_t *_dims)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
jl_array_t *a;
size_t ndims = jl_nfields(_dims);
assert(is_ntuple_long(_dims));
Expand All @@ -222,7 +222,7 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,

int ndimwords = jl_array_ndimwords(ndims);
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords * sizeof(size_t) + sizeof(void*), JL_SMALL_BYTE_ALIGNMENT);
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
// No allocation or safepoint allowed after this
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
a->flags.ndims = ndims;
Expand Down Expand Up @@ -298,12 +298,12 @@ JL_DLLEXPORT jl_array_t *jl_reshape_array(jl_value_t *atype, jl_array_t *data,

JL_DLLEXPORT jl_array_t *jl_string_to_array(jl_value_t *str)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
jl_array_t *a;

int ndimwords = jl_array_ndimwords(1);
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t) + sizeof(void*), JL_SMALL_BYTE_ALIGNMENT);
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, jl_array_uint8_type);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, jl_array_uint8_type);
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
a->flags.ndims = 1;
a->offset = 0;
Expand All @@ -327,7 +327,7 @@ JL_DLLEXPORT jl_array_t *jl_string_to_array(jl_value_t *str)
JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
size_t nel, int own_buffer)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
jl_array_t *a;
jl_value_t *eltype = jl_tparam0(atype);

Expand All @@ -350,7 +350,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,

int ndimwords = jl_array_ndimwords(1);
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
// No allocation or safepoint allowed after this
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
a->data = data;
Expand All @@ -365,7 +365,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
a->flags.isaligned = 0; // TODO: allow passing memalign'd buffers
if (own_buffer) {
a->flags.how = 2;
jl_gc_track_malloced_array(ptls, a);
jl_gc_track_malloced_array(ct->ptls, a);
jl_gc_count_allocd(nel*elsz + (elsz == 1 ? 1 : 0));
}
else {
Expand All @@ -381,7 +381,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array_1d(jl_value_t *atype, void *data,
JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
jl_value_t *_dims, int own_buffer)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
size_t nel = 1;
jl_array_t *a;
size_t ndims = jl_nfields(_dims);
Expand Down Expand Up @@ -417,7 +417,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,

int ndimwords = jl_array_ndimwords(ndims);
int tsz = JL_ARRAY_ALIGN(sizeof(jl_array_t) + ndimwords*sizeof(size_t), JL_CACHE_BYTE_ALIGNMENT);
a = (jl_array_t*)jl_gc_alloc(ptls, tsz, atype);
a = (jl_array_t*)jl_gc_alloc(ct->ptls, tsz, atype);
// No allocation or safepoint allowed after this
a->flags.pooled = tsz <= GC_MAX_SZCLASS;
a->data = data;
Expand All @@ -433,7 +433,7 @@ JL_DLLEXPORT jl_array_t *jl_ptr_to_array(jl_value_t *atype, void *data,
a->flags.isaligned = 0;
if (own_buffer) {
a->flags.how = 2;
jl_gc_track_malloced_array(ptls, a);
jl_gc_track_malloced_array(ct->ptls, a);
jl_gc_count_allocd(nel*elsz + (elsz == 1 ? 1 : 0));
}
else {
Expand Down Expand Up @@ -519,7 +519,8 @@ JL_DLLEXPORT jl_value_t *jl_pchar_to_string(const char *str, size_t len)
jl_throw(jl_memory_exception);
if (len == 0)
return jl_an_empty_string;
jl_value_t *s = jl_gc_alloc_(jl_get_ptls_states(), sz, jl_string_type); // force inlining
jl_task_t *ct = jl_current_task;
jl_value_t *s = jl_gc_alloc_(ct->ptls, sz, jl_string_type); // force inlining
*(size_t*)s = len;
memcpy((char*)s + sizeof(size_t), str, len);
((char*)s + sizeof(size_t))[len] = 0;
Expand All @@ -533,7 +534,8 @@ JL_DLLEXPORT jl_value_t *jl_alloc_string(size_t len)
jl_throw(jl_memory_exception);
if (len == 0)
return jl_an_empty_string;
jl_value_t *s = jl_gc_alloc_(jl_get_ptls_states(), sz, jl_string_type); // force inlining
jl_task_t *ct = jl_current_task;
jl_value_t *s = jl_gc_alloc_(ct->ptls, sz, jl_string_type); // force inlining
*(size_t*)s = len;
((char*)s + sizeof(size_t))[len] = 0;
return s;
Expand Down Expand Up @@ -672,7 +674,7 @@ JL_DLLEXPORT void jl_arrayunset(jl_array_t *a, size_t i)
// the **beginning** of the new buffer.
static int NOINLINE array_resize_buffer(jl_array_t *a, size_t newlen)
{
jl_ptls_t ptls = jl_get_ptls_states();
jl_task_t *ct = jl_current_task;
assert(!a->flags.isshared || a->flags.how == 3);
size_t elsz = a->elsize;
size_t nbytes = newlen * elsz;
Expand Down Expand Up @@ -714,12 +716,12 @@ static int NOINLINE array_resize_buffer(jl_array_t *a, size_t newlen)
newbuf = 1;
if (nbytes >= MALLOC_THRESH) {
a->data = jl_gc_managed_malloc(nbytes);
jl_gc_track_malloced_array(ptls, a);
jl_gc_track_malloced_array(ct->ptls, a);
a->flags.how = 2;
a->flags.isaligned = 1;
}
else {
a->data = jl_gc_alloc_buf(ptls, nbytes);
a->data = jl_gc_alloc_buf(ct->ptls, nbytes);
a->flags.how = 1;
jl_gc_wb_buf(a, a->data, nbytes);
}
Expand Down Expand Up @@ -1008,8 +1010,9 @@ STATIC_INLINE void jl_array_shrink(jl_array_t *a, size_t dec)
typetagdata = (char*)malloc_s(a->nrows);
memcpy(typetagdata, jl_array_typetagdata(a), a->nrows);
}
jl_task_t *ct = jl_current_task;
char *originaldata = (char*) a->data - a->offset * a->elsize;
char *newdata = (char*)jl_gc_alloc_buf(jl_get_ptls_states(), newbytes);
char *newdata = (char*)jl_gc_alloc_buf(ct->ptls, newbytes);
jl_gc_wb_buf(a, newdata, newbytes);
a->maxsize -= dec;
if (isbitsunion) {
Expand Down
Loading

0 comments on commit b1eca0f

Please sign in to comment.