Skip to content

Commit

Permalink
TL/MLX5: adding mcast interface
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jun 13, 2023
1 parent 6c2a4a8 commit ebc38fc
Show file tree
Hide file tree
Showing 11 changed files with 350 additions and 36 deletions.
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@

if TL_MLX5_ENABLED

mcast = \
mcast/tl_mlx5_mcast_context.c \
mcast/tl_mlx5_mcast.h \
mcast/tl_mlx5_mcast_coll.c \
mcast/tl_mlx5_mcast_coll.h \
mcast/tl_mlx5_mcast_team.c

sources = \
tl_mlx5.h \
tl_mlx5.c \
tl_mlx5_lib.c \
tl_mlx5_context.c \
tl_mlx5_team.c \
tl_mlx5_coll.h \
tl_mlx5_coll.c \
$(mcast) \
tl_mlx5_ib.h \
tl_mlx5_ib.c \
tl_mlx5_wqe.h \
Expand Down
72 changes: 72 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#ifndef UCC_MCAST_H
#define UCC_MCAST_H

#include <infiniband/ib.h>
#include <infiniband/umad.h>
#include <infiniband/verbs.h>
#include <rdma/rdma_verbs.h>
#include "utils/ucc_list.h"
#include "utils/ucc_mpool.h"
#include "components/tl/ucc_tl.h"
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"


#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true

typedef struct mcast_coll_comm_init_spec {
} mcast_coll_comm_init_spec_t;

typedef struct ucc_tl_mlx5_mcast_lib {
} ucc_tl_mlx5_mcast_lib_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
const ucc_base_config_t *);

typedef struct ucc_tl_mlx5_mcast_ctx_params {
} ucc_tl_mlx5_mcast_ctx_params_t;

typedef struct mcast_coll_context_t {
} mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_context_t {
} ucc_tl_mlx5_mcast_context_t;


typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */
} mcast_coll_comm_t;

typedef struct ucc_tl_mlx5_mcast_team {
void *mcast_comm;
} ucc_tl_mlx5_mcast_team_t;

typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */
} ucc_tl_mlx5_mcast_coll_req_t;

#define TASK_TEAM_MCAST(_task) \
(ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
#define TASK_CTX_MCAST(_task) \
(ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t))
#define TASK_LIB_MCAST(_task) \
(ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t))
#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
const ucc_base_team_params_t *params,
mcast_coll_comm_init_spec_t *mcast_conf);

ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx,
ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf);

#endif
74 changes: 74 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "tl_mlx5_coll.h"

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req)
{
return UCC_ERR_NOT_SUPPORTED;
}

ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr,
mcast_coll_comm_t *comm,
int is_blocking,
ucc_tl_mlx5_mcast_coll_req_t **task_req_handle)
{
return UCC_ERR_NOT_SUPPORTED;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task);
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast;
ucc_coll_args_t *args = &TASK_ARGS_MCAST(task);
ucc_datatype_t dt = args->src.info.datatype;
size_t count = args->src.info.count;
ucc_rank_t root = args->root;
ucc_status_t status = UCC_OK;
size_t data_size = ucc_dt_size(dt) * count;
void *buf = args->src.info.buffer;
mcast_coll_comm_t *comm = team->mcast_comm;

task->bcast_mcast.req_handle = NULL;

status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle);
if (UCC_OK != status && UCC_INPROGRESS != status) {
tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
coll_task->status = status;
return ucc_task_complete(coll_task);
}

coll_task->status = status;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
}

void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle;

if (req != NULL) {
status = ucc_tl_mlx5_mcast_test(req);
if (UCC_OK == status) {
coll_task->status = UCC_OK;
ucc_free(req);
task->bcast_mcast.req_handle = NULL;
}
}
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{

task->super.post = ucc_tl_mlx5_mcast_bcast_start;
task->super.progress = ucc_tl_mlx5_mcast_collective_progress;

return UCC_OK;
}
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#ifndef UCC_TL_MLX5_MCAST_COLL_H_
#define UCC_TL_MLX5_MCAST_COLL_H_

#include "tl_mlx5_mcast.h"
#include "tl_mlx5_coll.h"

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

#endif
18 changes: 18 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include <inttypes.h>
#include "tl_mlx5_mcast.h"
#include "utils/arch/cpu.h"
#include <ucs/sys/string.h>
#include "src/core/ucc_service_coll.h"
#include "tl_mlx5.h"

ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context,
ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf)
{
return UCC_ERR_NOT_SUPPORTED;
}
20 changes: 20 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/


#include "tl_mlx5.h"
#include "tl_mlx5_mcast_coll.h"
#include "coll_score/ucc_coll_score.h"

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
const ucc_base_team_params_t *params,
mcast_coll_comm_init_spec_t *mcast_conf)
{
return UCC_ERR_NOT_SUPPORTED;
}

49 changes: 27 additions & 22 deletions src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <infiniband/verbs.h>
#include <infiniband/mlx5dv.h>
#include "utils/arch/cpu.h"
#include "mcast/tl_mlx5_mcast.h"

#ifndef UCC_TL_MLX5_DEFAULT_SCORE
#define UCC_TL_MLX5_DEFAULT_SCORE 1
Expand Down Expand Up @@ -48,19 +49,21 @@ typedef struct ucc_tl_mlx5_ib_qp_conf {
} ucc_tl_mlx5_ib_qp_conf_t;

typedef struct ucc_tl_mlx5_lib_config {
ucc_tl_lib_config_t super;
int block_size;
int num_dci_qps;
int dc_threshold;
size_t dm_buf_size;
unsigned long dm_buf_num;
int dm_host;
ucc_tl_mlx5_ib_qp_conf_t qp_conf;
ucc_tl_lib_config_t super;
int block_size;
int num_dci_qps;
int dc_threshold;
size_t dm_buf_size;
unsigned long dm_buf_num;
int dm_host;
ucc_tl_mlx5_ib_qp_conf_t qp_conf;
mcast_coll_comm_init_spec_t mcast_conf;
} ucc_tl_mlx5_lib_config_t;

typedef struct ucc_tl_mlx5_context_config {
ucc_tl_context_config_t super;
ucs_config_names_array_t devices;
ucc_tl_context_config_t super;
ucs_config_names_array_t devices;
ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf;
} ucc_tl_mlx5_context_config_t;

typedef struct ucc_tl_mlx5_lib {
Expand All @@ -80,6 +83,7 @@ typedef struct ucc_tl_mlx5_context {
int is_imported;
int ib_port;
ucc_mpool_t req_mp;
ucc_tl_mlx5_mcast_context_t mcast;
} ucc_tl_mlx5_context_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t *,
const ucc_base_config_t *);
Expand All @@ -100,17 +104,18 @@ typedef enum
} ucc_tl_mlx5_team_state_t;

typedef struct ucc_tl_mlx5_team {
ucc_tl_team_t super;
ucc_status_t status[2];
ucc_service_coll_req_t *scoll_req;
ucc_tl_mlx5_team_state_t state;
void *dm_offset;
ucc_mpool_t dm_pool;
struct ibv_dm *dm_ptr;
struct ibv_mr *dm_mr;
ucc_tl_mlx5_a2a_t *a2a;
ucc_topo_t *topo;
ucc_ep_map_t ctx_map;
ucc_tl_team_t super;
ucc_status_t status[2];
ucc_service_coll_req_t *scoll_req;
ucc_tl_mlx5_team_state_t state;
void *dm_offset;
ucc_mpool_t dm_pool;
struct ibv_dm *dm_ptr;
struct ibv_mr *dm_mr;
ucc_tl_mlx5_a2a_t *a2a;
ucc_topo_t *topo;
ucc_ep_map_t ctx_map;
ucc_tl_mlx5_mcast_team_t *mcast;
} ucc_tl_mlx5_team_t;
UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);
Expand All @@ -126,7 +131,7 @@ typedef struct ucc_tl_mlx5_rcache_region {
ucc_tl_mlx5_reg_t reg;
} ucc_tl_mlx5_rcache_region_t;

#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL)
#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_BCAST)

#define UCC_TL_MLX5_TEAM_LIB(_team) \
(ucc_derived_of((_team)->super.super.context->lib, ucc_tl_mlx5_lib_t))
Expand Down
52 changes: 52 additions & 0 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "tl_mlx5_coll.h"
#include "mcast/tl_mlx5_mcast_coll.h"

static ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
UCC_TL_MLX5_PROFILE_REQUEST_FREE(task);
ucc_mpool_put(task);
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h)
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;

if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
return UCC_ERR_NOT_SUPPORTED;
}

task = ucc_tl_mlx5_get_task(coll_args, team);

if (ucc_unlikely(!task)) {
return UCC_ERR_NO_MEMORY;
}

task->super.finalize = ucc_tl_mlx5_task_finalize;

status = ucc_tl_mlx5_mcast_bcast_init(task);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

*task_h = &(task->super);

tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);

return UCC_OK;

free_task:
ucc_mpool_put(task);
return status;
}
Loading

0 comments on commit ebc38fc

Please sign in to comment.