Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/SHARP: add support for sharpv3 dt #661

Merged
merged 1 commit into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/components/tl/sharp/tl_sharp.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ static ucc_config_field_t ucc_tl_sharp_context_config_table[] = {
ucc_offsetof(ucc_tl_sharp_context_config_t, context_per_team),
UCC_CONFIG_TYPE_BOOL},


{"RAND_SEED", "0",
"Seed for random sharp job ID. (0 - use default).",
ucc_offsetof(ucc_tl_sharp_context_config_t, rand_seed),
Expand Down
36 changes: 20 additions & 16 deletions src/components/tl/sharp/tl_sharp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,31 @@
#include "core/ucc_ee.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"

#include <sharp/api/version.h>
#include <sharp/api/sharp_coll.h>

enum sharp_datatype ucc_to_sharp_dtype[] = {
[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT16)] = SHARP_DTYPE_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT32)] = SHARP_DTYPE_INT,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT64)] = SHARP_DTYPE_LONG,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT128)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT16)] = SHARP_DTYPE_UNSIGNED_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT32)] = SHARP_DTYPE_UNSIGNED,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT64)] = SHARP_DTYPE_UNSIGNED_LONG,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT128)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT16)] = SHARP_DTYPE_FLOAT_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT32)] = SHARP_DTYPE_FLOAT,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT64)] = SHARP_DTYPE_DOUBLE,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT128)] = SHARP_DTYPE_NULL,
// TODO in hpcx-2.11 add UCC_DT_BFLOAT16
[UCC_DT_PREDEFINED_ID(UCC_DT_INT16)] = SHARP_DTYPE_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT32)] = SHARP_DTYPE_INT,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT64)] = SHARP_DTYPE_LONG,
[UCC_DT_PREDEFINED_ID(UCC_DT_INT128)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT16)] = SHARP_DTYPE_UNSIGNED_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT32)] = SHARP_DTYPE_UNSIGNED,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT64)] = SHARP_DTYPE_UNSIGNED_LONG,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT128)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT16)] = SHARP_DTYPE_FLOAT_SHORT,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT32)] = SHARP_DTYPE_FLOAT,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT64)] = SHARP_DTYPE_DOUBLE,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT128)] = SHARP_DTYPE_NULL,
#if SHARP_API > SHARP_VERSION(3, 0)
[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] = SHARP_DTYPE_UNKNOWN,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] = SHARP_DTYPE_UNKNOWN,
[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] = SHARP_DTYPE_UNKNOWN,
#else
[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] = SHARP_DTYPE_NULL,
#endif
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT32_COMPLEX)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT64_COMPLEX)] = SHARP_DTYPE_NULL,
[UCC_DT_PREDEFINED_ID(UCC_DT_FLOAT128_COMPLEX)] = SHARP_DTYPE_NULL,
Expand Down
7 changes: 6 additions & 1 deletion src/components/tl/sharp/tl_sharp_coll.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -9,6 +9,11 @@

#include "tl_sharp.h"

/*
 * Sentinel stored in ucc_to_sharp_dtype[] for datatypes whose SHARP support
 * cannot be decided at build time; such entries need to be resolved by
 * querying for datatype support at runtime.
 * Parenthesized so the negative literal expands safely in any expression.
 */
#define SHARP_DTYPE_UNKNOWN (-1)

extern enum sharp_datatype ucc_to_sharp_dtype[];

ucc_status_t ucc_tl_sharp_allreduce_init(ucc_tl_sharp_task_t *task);

ucc_status_t ucc_tl_sharp_barrier_init(ucc_tl_sharp_task_t *task);
Expand Down
56 changes: 50 additions & 6 deletions src/components/tl/sharp/tl_sharp_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
#include "components/mc/ucc_mc.h"
#include "core/ucc_ee.h"
#include "coll_score/ucc_coll_score.h"
#include <sharp/api/version.h>

UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
ucc_tl_sharp_context_t *ctx =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alignment

ucc_derived_of(tl_context, ucc_tl_sharp_context_t);
struct sharp_coll_context *sharp_ctx = ctx->sharp_context;
struct sharp_coll_context *sharp_ctx = ctx->sharp_context;
struct sharp_coll_comm_init_spec comm_spec;
int ret;
ucc_status_t status;
int ret;
ucc_status_t status;

if (!(params->params.mask & UCC_TEAM_PARAM_FIELD_OOB)) {
tl_debug(ctx->super.super.lib, "team OOB required for sharp team");
Expand Down Expand Up @@ -64,6 +65,48 @@ UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
self->rcache = ctx->rcache;
}

#if SHARP_API > SHARP_VERSION(3, 0)
    /*
     * With SHARP API > 3.0 the INT8, UINT8 and BFLOAT16 entries of
     * ucc_to_sharp_dtype[] may hold SHARP_DTYPE_UNKNOWN. While at least one
     * of them is still unresolved, query the SHARP context capabilities and
     * replace each entry with either the matching SHARP datatype or
     * SHARP_DTYPE_NULL.
     *
     * The table is indexed by UCC predefined datatype ids, so the lookups
     * must use UCC_DT_* constants; the stored values are SHARP_DTYPE_*.
     */
    if ((ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] ==
         SHARP_DTYPE_UNKNOWN) ||
        (ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] ==
         SHARP_DTYPE_UNKNOWN) ||
        (ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] ==
         SHARP_DTYPE_UNKNOWN)) {
        struct sharp_coll_caps sharp_caps;

        ret = sharp_coll_caps_query(sharp_ctx, &sharp_caps);
        if (ret < 0) {
            tl_error(ctx->super.super.lib,
                     "sharp_coll_caps_query failed: %s(%d)",
                     sharp_coll_strerror(ret), ret);
            /* set an error code before jumping to cleanup, as the
             * sharp_coll_comm_init failure path does */
            status = UCC_ERR_NO_RESOURCE;
            goto cleanup;
        }

        if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_INT8)) {
            tl_debug(ctx->super.super.lib,
                     "enabling support for UCC_DT_INT8");
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] =
                SHARP_DTYPE_INT8;
        } else {
            tl_debug(ctx->super.super.lib,
                     "disabling support for UCC_DT_INT8");
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_INT8)] =
                SHARP_DTYPE_NULL;
        }

        if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_UINT8)) {
            tl_debug(ctx->super.super.lib,
                     "enabling support for UCC_DT_UINT8");
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] =
                SHARP_DTYPE_UINT8;
        } else {
            tl_debug(ctx->super.super.lib,
                     "disabling support for UCC_DT_UINT8");
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_UINT8)] =
                SHARP_DTYPE_NULL;
        }

        if (sharp_caps.support_mask.dtypes & UCC_BIT(SHARP_DTYPE_BFLOAT16)) {
            tl_debug(ctx->super.super.lib,
                     "enabling support for UCC_DT_BFLOAT16");
            /* store the SHARP datatype, not the UCC one */
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] =
                SHARP_DTYPE_BFLOAT16;
        } else {
            tl_debug(ctx->super.super.lib,
                     "disabling support for UCC_DT_BFLOAT16");
            ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(UCC_DT_BFLOAT16)] =
                SHARP_DTYPE_NULL;
        }
    }
#endif

comm_spec.rank = UCC_TL_TEAM_RANK(self);
comm_spec.size = UCC_TL_TEAM_SIZE(self);
comm_spec.group_world_ranks = NULL;
Expand All @@ -72,8 +115,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
ret = sharp_coll_comm_init(sharp_ctx,
&comm_spec, &self->sharp_comm);
if (ret < 0) {
tl_error(ctx->super.super.lib,
"sharp group create failed:%s(%d)",
tl_error(ctx->super.super.lib, "sharp group create failed:%s(%d)",
sharp_coll_strerror(ret), ret);
status = UCC_ERR_NO_RESOURCE;
goto cleanup;
Expand All @@ -88,7 +130,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_sharp_team_t, ucc_base_context_t *tl_context,
}
if (self->sharp_context) {
ucc_context_progress_deregister(
tl_context->ucc_context, (ucc_context_progress_fn_t)sharp_coll_progress, self->sharp_context);
tl_context->ucc_context,
(ucc_context_progress_fn_t)sharp_coll_progress,
self->sharp_context);
sharp_coll_finalize(self->sharp_context);
}
}
Expand Down