Merged

48 commits
76bd1ba
update
tchaton Nov 8, 2022
a753b8a
update
tchaton Nov 8, 2022
147e037
Merge branch 'master' into enable_debug_mode
tchaton Nov 8, 2022
5a8a779
Merge branch 'master' into enable_debug_mode
tchaton Nov 8, 2022
447adc2
update
tchaton Nov 8, 2022
5d9cd51
Merge branch 'enable_debug_mode' of https://github.com/Lightning-AI/l…
tchaton Nov 8, 2022
f63c47d
update
tchaton Nov 8, 2022
a67632c
update
tchaton Nov 8, 2022
50efaaf
update
tchaton Nov 8, 2022
ad64687
update
tchaton Nov 8, 2022
eedd770
Merge branch 'master' into enable_debug_mode
tchaton Nov 9, 2022
1ff6571
update
tchaton Nov 9, 2022
6cc05ce
Merge branch 'enable_debug_mode' of https://github.com/Lightning-AI/l…
tchaton Nov 9, 2022
b75524f
update
tchaton Nov 9, 2022
c63307e
update
tchaton Nov 9, 2022
34b893e
update
tchaton Nov 9, 2022
92c8ce6
update
tchaton Nov 9, 2022
83d303f
update
tchaton Nov 9, 2022
c0308c9
update
tchaton Nov 9, 2022
e0b406b
update
tchaton Nov 9, 2022
55d3406
update
tchaton Nov 9, 2022
2d924ab
update
tchaton Nov 9, 2022
3614a90
update
tchaton Nov 9, 2022
bc11d7b
update
tchaton Nov 9, 2022
756685e
update
tchaton Nov 9, 2022
17ace2d
Merge branch 'master' into enable_debug_mode_2
tchaton Nov 10, 2022
2a4a660
update
tchaton Nov 10, 2022
81888ef
update
tchaton Nov 10, 2022
2382cca
update
tchaton Nov 10, 2022
8c23ca3
update
tchaton Nov 10, 2022
324c9fe
update
tchaton Nov 10, 2022
434a20f
update
tchaton Nov 10, 2022
78c4146
update
tchaton Nov 10, 2022
806bd25
update
tchaton Nov 10, 2022
061fdac
update
tchaton Nov 10, 2022
831b19e
update
tchaton Nov 10, 2022
f2d6090
update
tchaton Nov 10, 2022
c8d727f
update
tchaton Nov 10, 2022
2f8cc1a
update
tchaton Nov 10, 2022
830ac24
update
tchaton Nov 10, 2022
e51d57e
update
tchaton Nov 10, 2022
1402f34
update
tchaton Nov 10, 2022
376a4a3
update
tchaton Nov 10, 2022
a14b275
Merge branch 'master' into enable_debug_mode_2
tchaton Nov 10, 2022
3d37fa5
Merge branch 'master' into enable_debug_mode_2
tchaton Nov 10, 2022
6856a13
Merge branch 'master' into enable_debug_mode_2
tchaton Nov 10, 2022
deafa68
Merge branch 'master' into enable_debug_mode_2
tchaton Nov 11, 2022
58afef9
Merge branch 'enable_debug_mode_2' of https://github.com/Lightning-AI…
tchaton Nov 11, 2022
27 changes: 15 additions & 12 deletions examples/app_multi_node/train_lite.py
@@ -6,23 +6,26 @@


 class LitePyTorchDistributed(L.LightningWork):
-    @staticmethod
-    def run():
-        # 1. Create LightningLite.
-        lite = LightningLite(strategy="ddp", precision=16)
+    def run(self):
+        # 1. Prepare the model
+        model = torch.nn.Sequential(
+            torch.nn.Linear(1, 1),
+            torch.nn.ReLU(),
+            torch.nn.Linear(1, 1),
+        )

-        # 2. Prepare distributed model and optimizer.
-        model = torch.nn.Linear(32, 2)
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        model, optimizer = lite.setup(model, optimizer)
+        # 2. Create LightningLite.
+        lite = LightningLite(strategy="ddp", precision=16)
+        model, optimizer = lite.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
         criterion = torch.nn.MSELoss()

-        # 3. Train the model for 50 steps.
-        for step in range(50):
+        # 3. Train the model for 1000 steps.
+        for step in range(1000):
             model.zero_grad()
-            x = torch.randn(64, 32).to(lite.device)
+            x = torch.tensor([0.8]).to(lite.device)
+            target = torch.tensor([1.0]).to(lite.device)
             output = model(x)
-            loss = criterion(output, torch.ones_like(output))
+            loss = criterion(output, target)
             print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
             lite.backward(loss)
             optimizer.step()
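With the staticmethod restriction dropped (see the lite.py, trainer.py, and pytorch_spawn.py changes further down), run can be a plain instance method. A minimal sketch of how this work might be wired into a multi-node app, assuming the component in lite.py is exported as LiteMultiNode and assuming the node count and compute preset; none of this wiring is shown in the diff itself:

import lightning as L
from lightning_app.components.multi_node.lite import LiteMultiNode  # assumed class name

# Hypothetical wiring for illustration only; arguments are assumptions, not taken from this PR.
app = L.LightningApp(
    LiteMultiNode(
        LitePyTorchDistributed,               # the work defined in train_lite.py above
        num_nodes=2,                          # assumed node count
        cloud_compute=L.CloudCompute("gpu"),  # assumed compute preset
    )
)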
5 changes: 2 additions & 3 deletions examples/app_multi_node/train_lt.py
@@ -4,11 +4,10 @@


 class LightningTrainerDistributed(L.LightningWork):
-    @staticmethod
-    def run():
+    def run(self):
         model = BoringModel()
         trainer = L.Trainer(
-            max_epochs=10,
+            max_steps=1000,
             strategy="ddp",
         )
         trainer.fit(model)
27 changes: 13 additions & 14 deletions examples/app_multi_node/train_pytorch.py
@@ -18,29 +18,28 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
         init_method=f"tcp://{main_address}:{main_port}",
     )

-    # 2. Prepare distributed model
-    model = torch.nn.Linear(32, 2)
+    # 2. Prepare the model
+    model = torch.nn.Sequential(
+        torch.nn.Linear(1, 1),
+        torch.nn.ReLU(),
+        torch.nn.Linear(1, 1),
+    )

     # 3. Setup distributed training
-    if torch.cuda.is_available():
-        device = torch.device(f"cuda:{local_rank}")
-        torch.cuda.set_device(device)
-    else:
-        device = torch.device("cpu")
-
-    model = model.to(device)
-    model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None)
+    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+    model = DistributedDataParallel(model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None)

     # 4. Prepare loss and optimizer
     criterion = torch.nn.MSELoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

-    # 5. Train the model for 50 steps.
-    for step in range(50):
+    # 5. Train the model for 1000 steps.
+    for step in range(1000):
         model.zero_grad()
-        x = torch.randn(64, 32).to(device)
+        x = torch.tensor([0.8]).to(device)
+        target = torch.tensor([1.0]).to(device)
         output = model(x)
-        loss = criterion(output, torch.ones_like(output))
+        loss = criterion(output, target)
         print(f"global_rank: {global_rank} step: {step} loss: {loss}")
         loss.backward()
         optimizer.step()
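The hunk header above truncates distributed_train's full parameter list. For orientation, the bookkeeping that init_process_group and the global_rank print rely on usually reduces to the arithmetic below; a hedged sketch of the common convention, not code from this PR:

import torch


def rank_math(num_nodes: int, node_rank: int, local_rank: int):
    """Assumed world-size/rank arithmetic for a homogeneous multi-node job."""
    nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1  # processes per node
    world_size = num_nodes * nprocs                                         # total processes
    global_rank = node_rank * nprocs + local_rank                           # unique rank across nodes
    return world_size, global_rank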
33 changes: 16 additions & 17 deletions examples/app_multi_node/train_pytorch_spawn.py
@@ -6,38 +6,37 @@


 class PyTorchDistributed(L.LightningWork):
-
-    # Note: Only staticmethod are support for now with `PyTorchSpawnMultiNode`
-    @staticmethod
     def run(
+        self,
         world_size: int,
         node_rank: int,
         global_rank: str,
         local_rank: int,
     ):
-        # 1. Prepare distributed model
-        model = torch.nn.Linear(32, 2)
+        # 1. Prepare the model
+        model = torch.nn.Sequential(
+            torch.nn.Linear(1, 1),
+            torch.nn.ReLU(),
+            torch.nn.Linear(1, 1),
+        )

         # 2. Setup distributed training
-        if torch.cuda.is_available():
-            device = torch.device(f"cuda:{local_rank}")
-            torch.cuda.set_device(device)
-        else:
-            device = torch.device("cpu")
-
-        model = model.to(device)
-        model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None)
+        device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
+        model = DistributedDataParallel(
+            model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None
+        )

         # 3. Prepare loss and optimizer
         criterion = torch.nn.MSELoss()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

-        # 4. Train the model for 50 steps.
-        for step in range(50):
+        # 4. Train the model for 1000 steps.
+        for step in range(1000):
             model.zero_grad()
-            x = torch.randn(64, 32).to(device)
+            x = torch.tensor([0.8]).to(device)
+            target = torch.tensor([1.0]).to(device)
             output = model(x)
-            loss = criterion(output, torch.ones_like(output))
+            loss = criterion(output, target)
             print(f"global_rank: {global_rank} step: {step} loss: {loss}")
             loss.backward()
             optimizer.step()
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `bi-directional` delta updates between the flow and the works ([#15582](https://github.com/Lightning-AI/lightning/pull/15582))

+- Enabled MultiNode Components to support state broadcasting ([#15607](https://github.com/Lightning-AI/lightning/pull/15607))
+

 ### Changed

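The entry above is the user-facing effect of the executor changes below: a multi-node work can now define run as an instance method and mutate its own attributes, and those changes are broadcast back to the flow. A minimal hedged illustration; the class and attribute names here are invented for the example and do not come from this PR:

import lightning as L


class TrackedWork(L.LightningWork):
    def __init__(self):
        super().__init__()
        self.last_loss = None  # plain state attribute, visible to the flow

    def run(self):
        for step in range(10):
            loss = 1.0 / (step + 1)  # stand-in for a real training step
            self.last_loss = loss    # delta broadcast back to the flow by the observer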
6 changes: 0 additions & 6 deletions src/lightning_app/components/multi_node/lite.py
@@ -7,7 +7,6 @@
 from lightning_app.components.multi_node.base import MultiNode
 from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 from lightning_app.utilities.tracer import Tracer

@@ -82,11 +81,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _LiteWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.
42 changes: 31 additions & 11 deletions src/lightning_app/components/multi_node/pytorch_spawn.py
@@ -3,10 +3,10 @@
 from typing_extensions import Protocol, runtime_checkable

 from lightning_app.components.multi_node.base import MultiNode
+from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
-from lightning_app.utilities.proxies import WorkRunExecutor
+from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver


 @runtime_checkable
@@ -22,6 +22,9 @@ def run(


 class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+
+    enable_start_observer: bool = False
+
     def __call__(
         self,
         main_address: str,
@@ -31,10 +34,31 @@ def __call__(
     ):
         import torch

-        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
-        torch.multiprocessing.spawn(
-            self.run, args=(self.work_run, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
-        )
+        with self.enable_spawn():
+            nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
+            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
+            torch.multiprocessing.spawn(
+                self.dispatch_run,
+                args=(self.__class__, self.work, queue, main_address, main_port, num_nodes, node_rank, nprocs),
+                nprocs=nprocs,
+            )
+
+    @staticmethod
+    def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs):
+        if local_rank == 0:
+            if isinstance(delta_queue, dict):
+                delta_queue = cls.process_queue(delta_queue)
+                work._request_queue = cls.process_queue(work._request_queue)
+                work._response_queue = cls.process_queue(work._response_queue)
+
+            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
+            state_observer.start()
+            _proxy_setattr(work, delta_queue, state_observer)
+
+        cls.run(local_rank, unwrap(work.run), *args, **kwargs)
+
+        if local_rank == 0:
+            state_observer.join(0)

     @staticmethod
     def run(
@@ -46,6 +70,7 @@ def run(
         node_rank: int,
         nprocs: int,
     ):
+
         import torch

         # 1. Setting distributed environment
@@ -76,11 +101,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _PyTorchSpawnWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.
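dispatch_run above is the heart of the change: rank 0 of the spawned group rebuilds the queues from their dict form, starts a WorkStateObserver, and proxies attribute writes on the work so they are forwarded as deltas while the user's unwrapped run executes. A rough hedged sketch of that setattr-forwarding pattern; this is illustrative only and is not the actual _proxy_setattr implementation:

class SetattrBroadcaster:
    """Hypothetical stand-in showing the pattern: mirror attribute writes onto a delta queue."""

    def __init__(self, work, delta_queue):
        object.__setattr__(self, "_work", work)
        object.__setattr__(self, "_delta_queue", delta_queue)

    def __setattr__(self, name, value):
        setattr(self._work, name, value)      # keep the real work up to date
        self._delta_queue.put({name: value})  # broadcast the change toward the flow

    def __getattr__(self, name):
        return getattr(self._work, name)      # reads fall through to the real work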
6 changes: 0 additions & 6 deletions src/lightning_app/components/multi_node/trainer.py
@@ -7,7 +7,6 @@
 from lightning_app.components.multi_node.base import MultiNode
 from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.app_helpers import is_static_method
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 from lightning_app.utilities.tracer import Tracer

@@ -81,11 +80,6 @@ def __init__(
         **work_kwargs: Any,
     ) -> None:
         assert issubclass(work_cls, _LightningTrainerWorkProtocol)
-        if not is_static_method(work_cls, "run"):
-            raise TypeError(
-                f"The provided {work_cls} run method needs to be static for now."
-                "HINT: Remove `self` and add staticmethod decorator."
-            )

         # Note: Private way to modify the work run executor
         # Probably exposed to the users in the future if needed.
33 changes: 29 additions & 4 deletions src/lightning_app/core/queues.py
@@ -235,12 +235,12 @@ def __init__(
         """
         if name is None:
             raise ValueError("You must specify a name for the queue")
-        host = host or REDIS_HOST
-        port = port or REDIS_PORT
-        password = password or REDIS_PASSWORD
+        self.host = host or REDIS_HOST
+        self.port = port or REDIS_PORT
+        self.password = password or REDIS_PASSWORD
         self.name = name
         self.default_timeout = default_timeout
-        self.redis = redis.Redis(host=host, port=port, password=password)
+        self.redis = redis.Redis(host=self.host, port=self.port, password=self.password)

     def put(self, item: Any) -> None:
         from lightning_app import LightningWork
@@ -329,6 +329,20 @@ def is_running(self) -> bool:
         except redis.exceptions.ConnectionError:
             return False

+    def to_dict(self):
+        return {
+            "type": "redis",
+            "name": self.name,
+            "default_timeout": self.default_timeout,
+            "host": self.host,
+            "port": self.port,
+            "password": self.password,
+        }
+
+    @classmethod
+    def from_dict(cls, state):
+        return cls(**state)
+

 class HTTPQueue(BaseQueue):
     def __init__(self, name: str, default_timeout: float):
@@ -414,6 +428,17 @@ def _split_app_id_and_queue_name(queue_name):
         app_id, queue_name = queue_name.split("_", 1)
         return app_id, queue_name

+    def to_dict(self):
+        return {
+            "type": "http",
+            "name": self.name,
+            "default_timeout": self.default_timeout,
+        }
+
+    @classmethod
+    def from_dict(cls, state):
+        return cls(**state)
+

 def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None:
     if QUEUE_DEBUG_ENABLED or (Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED").exists():
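The new to_dict/from_dict pair is what lets the spawn executor above hand a queue across the torch.multiprocessing.spawn boundary as a plain dict and rebuild it on the other side. A hedged round-trip sketch, assuming the Redis-backed class in this module is named RedisQueue; the rebuild_queue helper is hypothetical (the executor uses its own process_queue classmethod for the same purpose):

from lightning_app.core.queues import HTTPQueue, RedisQueue


def rebuild_queue(state: dict):
    """Hypothetical helper: rebuild a queue from its to_dict() form in a child process."""
    queue_type = state.pop("type")  # "redis" or "http", written by to_dict()
    if queue_type == "redis":
        return RedisQueue.from_dict(state)
    if queue_type == "http":
        return HTTPQueue.from_dict(state)
    raise ValueError(f"Unknown queue type: {queue_type}")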