Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

malloc wrappers: ensure thread-safe #33284

Merged
merged 1 commit into from
Sep 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4512,7 +4512,9 @@ static Function* gen_cfun_wrapper(
// for now, just use a dummy field to avoid a branch in this function
ctx.world_age_field = ctx.builder.CreateSelect(have_tls, ctx.world_age_field, dummy_world);
Value *last_age = tbaa_decorate(tbaa_gcframe, ctx.builder.CreateLoad(ctx.world_age_field));
have_tls = ctx.builder.CreateAnd(have_tls, ctx.builder.CreateIsNotNull(last_age));
Value *valid_tls = ctx.builder.CreateIsNotNull(last_age);
have_tls = ctx.builder.CreateAnd(have_tls, valid_tls);
ctx.world_age_field = ctx.builder.CreateSelect(valid_tls, ctx.world_age_field, dummy_world);
Value *world_v = ctx.builder.CreateLoad(prepare_global(jlgetworld_global));

Value *age_ok = NULL;
Expand Down
2 changes: 2 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2002,6 +2002,8 @@ static jl_value_t *jl_deserialize_value_any(jl_serializer_state *s, uint8_t tag,
int32_t nw = (sz == 0 ? 1 : (sz < 0 ? -sz : sz));
size_t nb = nw * gmp_limb_size;
void *buf = jl_gc_counted_malloc(nb);
if (buf == NULL)
jl_throw(jl_memory_exception);
ios_read(s->s, (char*)buf, nb);
jl_set_nth_field(v, 0, jl_box_int32(nw));
jl_set_nth_field(v, 1, sizefield);
Expand Down
69 changes: 34 additions & 35 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -3015,54 +3015,47 @@ JL_DLLEXPORT void jl_throw_out_of_memory_error(void)
JL_DLLEXPORT void *jl_gc_counted_malloc(size_t sz)
{
jl_ptls_t ptls = jl_get_ptls_states();
maybe_collect(ptls);
ptls->gc_num.allocd += sz;
ptls->gc_num.malloc++;
void *b = malloc(sz);
if (b == NULL)
jl_throw(jl_memory_exception);
return b;
if (ptls && ptls->world_age) {
maybe_collect(ptls);
ptls->gc_num.allocd += sz;
ptls->gc_num.malloc++;
}
return malloc(sz);
}

JL_DLLEXPORT void *jl_gc_counted_calloc(size_t nm, size_t sz)
{
jl_ptls_t ptls = jl_get_ptls_states();
maybe_collect(ptls);
ptls->gc_num.allocd += nm*sz;
ptls->gc_num.malloc++;
void *b = calloc(nm, sz);
if (b == NULL)
jl_throw(jl_memory_exception);
return b;
if (ptls && ptls->world_age) {
maybe_collect(ptls);
ptls->gc_num.allocd += nm*sz;
ptls->gc_num.malloc++;
}
return calloc(nm, sz);
}

JL_DLLEXPORT void jl_gc_counted_free_with_size(void *p, size_t sz)
{
jl_ptls_t ptls = jl_get_ptls_states();
free(p);
ptls->gc_num.freed += sz;
ptls->gc_num.freecall++;
}

// older name for jl_gc_counted_free_with_size
JL_DLLEXPORT void jl_gc_counted_free(void *p, size_t sz)
{
jl_gc_counted_free_with_size(p, sz);
if (ptls && ptls->world_age) {
ptls->gc_num.freed += sz;
ptls->gc_num.freecall++;
}
}

JL_DLLEXPORT void *jl_gc_counted_realloc_with_old_size(void *p, size_t old, size_t sz)
{
jl_ptls_t ptls = jl_get_ptls_states();
maybe_collect(ptls);
if (sz < old)
ptls->gc_num.freed += (old - sz);
else
ptls->gc_num.allocd += (sz - old);
ptls->gc_num.realloc++;
void *b = realloc(p, sz);
if (b == NULL)
jl_throw(jl_memory_exception);
return b;
if (ptls && ptls->world_age) {
maybe_collect(ptls);
if (sz < old)
ptls->gc_num.freed += (old - sz);
else
ptls->gc_num.allocd += (sz - old);
ptls->gc_num.realloc++;
}
return realloc(p, sz);
}

// allocation wrappers that save the size of allocations, to allow using
Expand All @@ -3071,16 +3064,20 @@ JL_DLLEXPORT void *jl_gc_counted_realloc_with_old_size(void *p, size_t old, size
JL_DLLEXPORT void *jl_malloc(size_t sz)
{
int64_t *p = (int64_t *)jl_gc_counted_malloc(sz + JL_SMALL_BYTE_ALIGNMENT);
if (p == NULL)
return NULL;
p[0] = sz;
return (void *)(p + 2);
return (void *)(p + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
}

JL_DLLEXPORT void *jl_calloc(size_t nm, size_t sz)
{
size_t nmsz = nm*sz;
int64_t *p = (int64_t *)jl_gc_counted_calloc(nmsz + JL_SMALL_BYTE_ALIGNMENT, 1);
if (p == NULL)
return NULL;
p[0] = nmsz;
return (void *)(p + 2);
return (void *)(p + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
}

JL_DLLEXPORT void jl_free(void *p)
Expand All @@ -3105,8 +3102,10 @@ JL_DLLEXPORT void *jl_realloc(void *p, size_t sz)
szold = pp[0] + JL_SMALL_BYTE_ALIGNMENT;
}
int64_t *pnew = (int64_t *)jl_gc_counted_realloc_with_old_size(pp, szold, sz + JL_SMALL_BYTE_ALIGNMENT);
if (pnew == NULL)
return NULL;
pnew[0] = sz;
return (void *)(pnew + 2);
return (void *)(pnew + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
}

// allocating blocks for Arrays and Strings
Expand Down
1 change: 0 additions & 1 deletion src/jl_uv.c
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,6 @@ struct work_baton {
void *work_args;
void *work_retval;
notify_cb_t notify_func;
int tid;
int notify_idx;
};

Expand Down
14 changes: 14 additions & 0 deletions test/ccall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,20 @@ end

@test ccall(:jl_getpagesize, Clong, ()) == @threadcall(:jl_getpagesize, Clong, ())

# make sure our malloc/realloc/free adapters are thread-safe and repeatable
for i = 1:8
ptr = @threadcall(:jl_malloc, Ptr{Cint}, (Csize_t,), sizeof(Cint))
@test ptr != C_NULL
unsafe_store!(ptr, 3)
@test unsafe_load(ptr) == 3
ptr = @threadcall(:jl_realloc, Ptr{Cint}, (Ptr{Cint}, Csize_t,), ptr, 2 * sizeof(Cint))
@test ptr != C_NULL
unsafe_store!(ptr, 4, 2)
@test unsafe_load(ptr, 1) == 3
@test unsafe_load(ptr, 2) == 4
@threadcall(:jl_free, Cvoid, (Ptr{Cint},), ptr)
end

# Pointer finalizer (issue #15408)
let A = [1]
ccall((:set_c_int, libccalltest), Cvoid, (Cint,), 1)
Expand Down