Skip to content
This repository has been archived by the owner on Dec 4, 2024. It is now read-only.

Commit

Permalink
use an enum for modes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielenricocahall committed Nov 17, 2023
1 parent d27e695 commit 8eaf2d6
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 34 deletions.
Empty file added elephas/enums/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions elephas/enums/modes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Training-mode constants shared across elephas.

``Mode`` is a string-valued enum, so existing callers that pass plain
strings ('synchronous', 'asynchronous', 'hogwild') keep working via
equality comparison and ``Mode(value)`` lookup.
"""
import sys

# Compare the full (major, minor) tuple: checking only ``minor`` would
# misclassify any other major release (e.g. a 4.0 interpreter has
# minor == 0 < 11 and would wrongly take the legacy branch).
if sys.version_info < (3, 11):
    # Pre-3.11 interpreters emulate StrEnum with the str mixin.
    from enum import Enum


    class Mode(str, Enum):
        # Values are the lowercase member names, matching what
        # StrEnum's auto() produces in the 3.11+ branch below.
        SYNCHRONOUS = 'synchronous'
        ASYNCHRONOUS = 'asynchronous'
        HOGWILD = 'hogwild'
else:
    # 3.11+ has the StrEnum builtin, but mixing ``str`` into StrEnum
    # is rejected, hence the separate branch:
    # https://github.com/python/cpython/issues/100458
    from enum import StrEnum, auto


    class Mode(StrEnum):
        SYNCHRONOUS = auto()   # 'synchronous'
        ASYNCHRONOUS = auto()  # 'asynchronous'
        HOGWILD = auto()       # 'hogwild'
3 changes: 2 additions & 1 deletion elephas/parameter/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from multiprocessing import Process
from tensorflow.keras.models import Model

from elephas.enums.modes import Mode
from elephas.utils.sockets import determine_master
from elephas.utils.sockets import receive, send
from elephas.utils.serialization import dict_to_model
Expand Down Expand Up @@ -44,7 +45,7 @@ def make_write_threadsafe_if_necessary(self, func):
return self.make_threadsafe_if_necessary(func, self.lock.acquire_write)

def make_threadsafe_if_necessary(self, func, lock_aquire_callable):
if self.mode == 'asynchronous':
if self.mode == Mode.ASYNCHRONOUS:
@wraps(func)
def wrapper(*args, **kwargs):
lock_aquire_callable()
Expand Down
17 changes: 9 additions & 8 deletions elephas/spark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tensorflow.keras.optimizers import get as get_optimizer
from tensorflow.keras.optimizers import serialize as serialize_optimizer, deserialize as deserialize_optimizer

from .enums.modes import Mode
from .mllib import to_matrix, from_matrix, to_vector, from_vector
from .parameter.factory import ClientServerFactory
from .utils import lp_to_simple_rdd, to_simple_rdd
Expand All @@ -27,7 +28,7 @@

class SparkModel:

def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http', num_workers=None,
def __init__(self, model, mode=Mode.ASYNCHRONOUS, frequency='epoch', parameter_server_mode='http', num_workers=None,
custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
"""SparkModel
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_serv
self.kwargs = kwargs

self.serialized_model = model_to_dict(model)
if self.mode != 'synchronous':
if self.mode != Mode.SYNCHRONOUS:
factory = ClientServerFactory.get_factory(self.parameter_server_mode)
self.parameter_server = factory.create_server(self.serialized_model, self.port, self.mode,
custom_objects=self.custom_objects)
Expand Down Expand Up @@ -182,7 +183,7 @@ def fit(self, rdd: RDD, **kwargs):
if self.num_workers:
rdd = rdd.repartition(self.num_workers)

if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
if self.mode in [mode for mode in Mode]:
self._fit(rdd, **kwargs)
else:
raise ValueError(
Expand All @@ -193,7 +194,7 @@ def _fit(self, rdd: RDD, **kwargs):
self._master_network.compile(optimizer=get_optimizer(self.master_optimizer),
loss=self.master_loss,
metrics=self.master_metrics)
if self.mode in ['asynchronous', 'hogwild']:
if self.mode in [Mode.ASYNCHRONOUS, Mode.HOGWILD]:
self.start_server()
train_config = kwargs
freq = self.frequency
Expand All @@ -206,15 +207,15 @@ def _fit(self, rdd: RDD, **kwargs):
init = self._master_network.get_weights()
parameters = rdd.context.broadcast(init)

if self.mode in ['asynchronous', 'hogwild']:
if self.mode in [Mode.ASYNCHRONOUS, Mode.HOGWILD]:
print('>>> Initialize workers')
worker = AsynchronousSparkWorker(
model_json, parameters, self.client, train_config, freq, optimizer, loss, metrics, custom)
print('>>> Distribute load')
rdd.mapPartitions(worker.train).collect()
print('>>> Async training complete.')
new_parameters = self.client.get_parameters()
elif self.mode == 'synchronous':
elif self.mode == Mode.SYNCHRONOUS:
worker = SparkWorker(model_json, parameters, train_config,
optimizer, loss, metrics, custom)
training_outcomes = rdd.mapPartitions(worker.train).collect()
Expand All @@ -229,7 +230,7 @@ def _fit(self, rdd: RDD, **kwargs):
else:
raise ValueError("Unsupported mode {}".format(self.mode))
self._master_network.set_weights(new_parameters)
if self.mode in ['asynchronous', 'hogwild']:
if self.mode in [Mode.ASYNCHRONOUS, Mode.HOGWILD]:
self.stop_server()

def _predict(self, rdd: RDD) -> List[np.ndarray]:
Expand Down Expand Up @@ -310,7 +311,7 @@ def _evaluate(model, optimizer, loss: Callable[[tf.Tensor, tf.Tensor], tf.Tensor

class SparkMLlibModel(SparkModel):

def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
def __init__(self, model, mode=Mode.ASYNCHRONOUS, frequency='epoch', parameter_server_mode='http',
num_workers=4, elephas_optimizer=None, custom_objects=None, batch_size=32, port=4000, *args, **kwargs):
"""SparkMLlibModel
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers.legacy import SGD

from elephas.enums.modes import Mode
from elephas.spark_model import SparkModel
from elephas.utils import to_simple_rdd


@pytest.mark.parametrize('mode', ['synchronous', 'asynchronous', 'hogwild'])
@pytest.mark.parametrize('mode', [mode for mode in Mode])
def test_training_custom_activation(mode, spark_context):
def custom_activation(x):
return sigmoid(x) + 1
Expand Down
43 changes: 22 additions & 21 deletions tests/integration/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from tensorflow.keras.optimizers.legacy import SGD

from elephas.enums.modes import Mode
from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd

Expand All @@ -16,16 +17,16 @@ def _generate_port_number(port=3000, _count=count(1)):
# enumerate possible combinations for training mode and parameter server for a classification model while also
validating multiple workers for repartitioning
@pytest.mark.parametrize('mode,parameter_server_mode,num_workers',
[('synchronous', None, None),
('synchronous', None, 2),
('asynchronous', 'http', None),
('asynchronous', 'http', 2),
('asynchronous', 'socket', None),
('asynchronous', 'socket', 2),
('hogwild', 'http', None),
('hogwild', 'http', 2),
('hogwild', 'socket', None),
('hogwild', 'socket', 2)])
[(Mode.SYNCHRONOUS, None, None),
(Mode.SYNCHRONOUS, None, 2),
(Mode.ASYNCHRONOUS, 'http', None),
(Mode.ASYNCHRONOUS, 'http', 2),
(Mode.ASYNCHRONOUS, 'socket', None),
(Mode.ASYNCHRONOUS, 'socket', 2),
(Mode.HOGWILD, 'http', None),
(Mode.HOGWILD, 'http', 2),
(Mode.HOGWILD, 'socket', None),
(Mode.HOGWILD, 'socket', 2)])
def test_training_classification(spark_context, mode, parameter_server_mode, num_workers, mnist_data, classification_model):
# Define basic parameters
batch_size = 64
Expand Down Expand Up @@ -70,16 +71,16 @@ def test_training_classification(spark_context, mode, parameter_server_mode, num
# enumerate possible combinations for training mode and parameter server for a regression model while also validating
# multiple workers for repartitioning
@pytest.mark.parametrize('mode,parameter_server_mode,num_workers',
[('synchronous', None, None),
('synchronous', None, 2),
('asynchronous', 'http', None),
('asynchronous', 'http', 2),
('asynchronous', 'socket', None),
('asynchronous', 'socket', 2),
('hogwild', 'http', None),
('hogwild', 'http', 2),
('hogwild', 'socket', None),
('hogwild', 'socket', 2)])
[(Mode.SYNCHRONOUS, None, None),
(Mode.SYNCHRONOUS, None, 2),
(Mode.ASYNCHRONOUS, 'http', None),
(Mode.ASYNCHRONOUS, 'http', 2),
(Mode.ASYNCHRONOUS, 'socket', None),
(Mode.ASYNCHRONOUS, 'socket', 2),
(Mode.HOGWILD, 'http', None),
(Mode.HOGWILD, 'http', 2),
(Mode.HOGWILD, 'socket', None),
(Mode.HOGWILD, 'socket', 2)])
def test_training_regression(spark_context, mode, parameter_server_mode, num_workers, boston_housing_dataset,
regression_model):
x_train, y_train, x_test, y_test = boston_housing_dataset
Expand Down Expand Up @@ -124,7 +125,7 @@ def test_training_regression_no_metrics(spark_context, boston_housing_dataset, r
epochs = 1
sgd = SGD(lr=0.0000001)
regression_model.compile(sgd, 'mse')
spark_model = SparkModel(regression_model, frequency='epoch', mode='synchronous', port=_generate_port_number())
spark_model = SparkModel(regression_model, frequency='epoch', mode=Mode.SYNCHRONOUS, port=_generate_port_number())

# Train Spark model
spark_model.fit(rdd, epochs=epochs, batch_size=batch_size, verbose=0, validation_split=0.1)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_mllib_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from tensorflow.keras.optimizers.legacy import RMSprop

from elephas.enums.modes import Mode
from elephas.spark_model import SparkMLlibModel, load_spark_model
from elephas.utils.rdd_utils import to_labeled_point

Expand All @@ -16,7 +17,7 @@ def test_serialization(classification_model):
rms = RMSprop()
classification_model.compile(rms, 'categorical_crossentropy', ['acc'])
spark_model = SparkMLlibModel(
classification_model, frequency='epoch', mode='synchronous', num_workers=2)
classification_model, frequency='epoch', mode=Mode.SYNCHRONOUS, num_workers=2)
spark_model.save("test.h5")
loaded_model = load_spark_model("test.h5")
assert loaded_model.master_network.to_json()
Expand All @@ -34,7 +35,7 @@ def test_mllib_model(spark_context, classification_model, mnist_data):

# Initialize SparkModel from tensorflow.keras model and Spark context
spark_model = SparkMLlibModel(
model=classification_model, frequency='epoch', mode='synchronous')
model=classification_model, frequency='epoch', mode=Mode.SYNCHRONOUS)

# Train Spark model
spark_model.fit(lp_rdd, epochs=5, batch_size=32, verbose=0,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input

from elephas.enums.modes import Mode
from elephas.spark_model import SparkModel


def test_sequential_serialization(spark_context, classification_model):
classification_model.compile(
optimizer="sgd", loss="categorical_crossentropy", metrics=["acc"])
spark_model = SparkModel(classification_model, frequency='epoch', mode='synchronous')
spark_model = SparkModel(classification_model, frequency='epoch', mode=Mode.SYNCHRONOUS)
spark_model.save("elephas_sequential.h5")


Expand Down

0 comments on commit 8eaf2d6

Please sign in to comment.