Removed distributed communication class
Added missing parameter to test

Fixed formatting

docs: fix build issues and add sub-sections (#69)
Sathwik Yanamaddi committed Apr 12, 2024
1 parent 9f3931f commit 43f159a
Showing 7 changed files with 72 additions and 29 deletions.
15 changes: 0 additions & 15 deletions axonn/communication.py
@@ -19,18 +19,6 @@
import numpy as np


class DistributedEnvironment:
def __init__(self):
self.world_size = int(os.environ["SLURM_NTASKS"])
self.local_rank = int(os.environ["SLURM_PROCID"])

def get_world_size(self):
return self.world_size

def get_rank(self):
return self.local_rank


class communication_handle:
"""
Communication handle for point-to-point (MPI) and collective
@@ -61,9 +49,6 @@ def __init__(
config.device = device
if config.device == "cpu":
self.backend = "gloo"
env = DistributedEnvironment()
self.world_rank = env.get_rank()
self.world_size = env.get_world_size()
else:
self.backend = "nccl"

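For context, the deleted DistributedEnvironment class read SLURM_NTASKS and SLURM_PROCID directly. Below is a minimal sketch, not AxoNN code, of how rank and world size are conventionally obtained on the gloo/CPU path without such a helper, assuming a torchrun-style launcher that exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.

import torch.distributed as dist

# Minimal sketch (not part of this commit): with DistributedEnvironment gone,
# rank and world size on the CPU/gloo path would typically come from the
# launcher-provided environment via torch.distributed itself, which defaults
# to the env:// initialization method.
def cpu_rank_and_world_size():
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")  # reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
    return dist.get_rank(), dist.get_world_size()
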
5 changes: 5 additions & 0 deletions axonn/tests/test_intra_layer_fc.py
@@ -82,6 +82,7 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias, set_devic
@pytest.mark.parametrize("easy_tp", [False, True])
@pytest.mark.parametrize("clip_grad_norm", [-1, 1e-3])
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("set_device", ["cuda", "cpu"])
def test_bw_pass(
G_intra_r,
G_intra_c,
@@ -92,6 +93,7 @@ def test_bw_pass(
easy_tp,
clip_grad_norm,
bias,
set_device,
):
# These tests are in fp-32
torch.manual_seed(42)
@@ -101,6 +103,9 @@
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
G_intra_d=G_intra_d,
mixed_precision=False,
fp16_allreduce=False,
device=set_device,
)
X = torch.randn(B, H).cuda() * 0.01
Y_grad = torch.randn(B, H).cuda() * 0.01
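The three new keyword arguments above land in a call whose target sits outside the hunk. A minimal sketch of what the parametrized setup implies follows, assuming the target is AxoNN's init; the function name and any required arguments not visible in the diff are assumptions, not confirmed by this commit.

from axonn import axonn as ax

# Hypothetical mirror of the test setup above. Only G_intra_r, G_intra_c,
# G_intra_d, mixed_precision, fp16_allreduce, and device appear in the hunk;
# any required arguments not shown there are omitted here for brevity.
def init_for_device(set_device, G_intra_r=1, G_intra_c=1, G_intra_d=1):
    ax.init(
        G_intra_r=G_intra_r,
        G_intra_c=G_intra_c,
        G_intra_d=G_intra_d,
        mixed_precision=False,   # the test runs in fp32, per its comment
        fp16_allreduce=False,
        device=set_device,       # "cuda" or "cpu", the new parametrization
    )
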
16 changes: 16 additions & 0 deletions docs/axonn_style.py
@@ -0,0 +1,16 @@
# Copyright 2024 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# The name of the Pygments (syntax highlighting) style to use.
from pygments.styles.default import DefaultStyle
from pygments.token import Generic


# modifications to the default style
class AxonnStyle(DefaultStyle):
styles = DefaultStyle.styles.copy()
background_color = "#f4f4f8"
styles[Generic.Output] = "#355"
styles[Generic.Prompt] = "bold #346ec9"
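
As a quick sanity check (not part of the commit), the new style class can be passed straight to a Pygments formatter; the only assumption is that axonn_style is importable, e.g. when running from docs/.

from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter
from axonn_style import AxonnStyle  # assumes the working directory is docs/

# HtmlFormatter accepts a style class directly, so the custom colors can be
# inspected without going through Sphinx.
formatter = HtmlFormatter(style=AxonnStyle)
print(highlight("print('hello AxoNN')", PythonLexer(), formatter))
print(formatter.get_style_defs(".highlight"))  # CSS includes background "#f4f4f8"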
17 changes: 3 additions & 14 deletions docs/conf.py
@@ -15,15 +15,11 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
import os
import sys

# sys.path.insert(0, os.path.abspath('.'))

# The name of the Pygments (syntax highlighting) style to use.
from pygments.styles.default import DefaultStyle
from pygments.token import Generic

import pkg_resources


@@ -60,21 +56,14 @@
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# modifications to the default style
class AxonnStyle(DefaultStyle):
styles = DefaultStyle.styles.copy()
background_color = "#f4f4f8"
styles[Generic.Output] = "#355"
styles[Generic.Prompt] = "bold #346ec9"


dist = pkg_resources.Distribution(__file__)
sys.path.append(".") # make 'conf' module findable
ep = pkg_resources.EntryPoint.parse("axonn = conf:AxonnStyle", dist=dist)
dist._ep_map = {"pygments.styles": {"plugin1": ep}}
pkg_resources.working_set.add(dist)

pygments_style = "axonn"
sys.path.insert(0, os.path.abspath("."))
pygments_style = "axonn_style.AxonnStyle"


# -- Options for HTML output -------------------------------------------------
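The new pygments_style value is a dotted path rather than a registered style name; Sphinx can resolve such a path by importing the module and attribute, which is why conf.py now puts docs/ on sys.path instead of registering an entry point. A standalone sketch of that lookup, assuming it is run from the docs/ directory:

import os
import sys
from importlib import import_module

# Standalone sketch of the dotted-style lookup the new conf.py relies on;
# docs/ must be importable, hence the sys.path insert mirrored from conf.py.
sys.path.insert(0, os.path.abspath("."))
module_name, class_name = "axonn_style.AxonnStyle".rsplit(".", 1)
style_class = getattr(import_module(module_name), class_name)
print(style_class.background_color)  # "#f4f4f8"
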
17 changes: 17 additions & 0 deletions docs/examples.rst
@@ -0,0 +1,17 @@

********
Examples
********

Training
============



Fine-tuning
===========



Inference
=========
3 changes: 3 additions & 0 deletions docs/index.rst
@@ -19,6 +19,9 @@ AxoNN is a parallel framework for training deep neural networks.
:caption: User Docs

getting_started
user_guide
examples



##################
28 changes: 28 additions & 0 deletions docs/user_guide.rst
@@ -0,0 +1,28 @@
**********
User Guide
**********

Initializing AxoNN
==================

Tensor with Easy API
====================

Tensor using Advanced API
=====================================

Combining Tensor in AxoNN with PyTorch DDP
==========================================

Integration with other Parallel APIs
====================================

Huggingface
-----------

Pipelining in AxoNN
===================

Combining Pipelining in AxoNN with Data Parallelism
===================================================
