Skip to content
This repository has been archived by the owner on Dec 4, 2024. It is now read-only.

Commit

Permalink
use decorator to reduce duplicated logic
Browse files Browse the repository at this point in the history
  • Loading branch information
danielenricocahall committed Nov 17, 2023
1 parent 25cca95 commit d27e695
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions elephas/parameter/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import pickle
import socket
from functools import wraps
from threading import Thread
from flask import Flask, request
from multiprocessing import Process
Expand All @@ -14,7 +15,7 @@
from elephas.utils import subtract_params


class BaseParameterServer(object):
class BaseParameterServer(abc.ABC):
"""BaseParameterServer
Parameter servers can be started and stopped. Server implementations have
Expand All @@ -25,18 +26,34 @@ def __init__(self, model: Model, port: int, mode: str, **kwargs):
self.port = port
self.mode = mode
self.master_network = dict_to_model(model, kwargs.get('custom_objects'))
self.lock = Lock()

@abc.abstractmethod
def start(self):
"""Start the parameter server instance.
"""
raise NotImplementedError

@abc.abstractmethod
def stop(self):
"""Terminate the parameter server instance.
"""
raise NotImplementedError
def make_read_threadsafe_if_necessary(self, func):
return self.make_threadsafe_if_necessary(func, self.lock.acquire_read)

def make_write_threadsafe_if_necessary(self, func):
return self.make_threadsafe_if_necessary(func, self.lock.acquire_write)

def make_threadsafe_if_necessary(self, func, lock_aquire_callable):
if self.mode == 'asynchronous':
@wraps(func)
def wrapper(*args, **kwargs):
lock_aquire_callable()
result = func(*args, **kwargs)
self.lock.release()
return result
return wrapper
else:
return func


class HttpServer(BaseParameterServer):
Expand Down Expand Up @@ -73,7 +90,6 @@ def __init__(self, model: Model, port: int, mode: str, **kwargs):
self.threaded = kwargs.get("threaded", True)
self.use_reloader = kwargs.get("use_reloader", False)

self.lock = Lock()
self.pickled_weights = None
self.weights = self.master_network.get_weights()

Expand Down Expand Up @@ -105,30 +121,23 @@ def home():
return 'Elephas'

@app.route('/parameters', methods=['GET'])
@self.make_read_threadsafe_if_necessary
def handle_get_parameters():
if self.mode == 'asynchronous':
self.lock.acquire_read()
self.pickled_weights = pickle.dumps(self.weights, -1)
pickled_weights = self.pickled_weights
if self.mode == 'asynchronous':
self.lock.release()
return pickled_weights

@app.route('/update', methods=['POST'])
@self.make_write_threadsafe_if_necessary
def handle_update_parameters():
delta = pickle.loads(request.data)
if self.mode == 'asynchronous':
self.lock.acquire_write()

if not self.master_network.built:
self.master_network.build()

# Just apply the gradient
weights_before = self.weights
self.weights = subtract_params(weights_before, delta)

if self.mode == 'asynchronous':
self.lock.release()
return 'Update done'

master_url = determine_master(self.port)
Expand Down Expand Up @@ -158,6 +167,8 @@ def __init__(self, model: Model, port: int, mode: str, **kwargs):
self.connections = []
self.lock = Lock()
self.thread = None
self.update_parameters = self.make_write_threadsafe_if_necessary(self.update_parameters)
self.get_parameters = self.make_read_threadsafe_if_necessary(self.get_parameters)

def start(self):
if self.thread is not None:
Expand Down Expand Up @@ -202,20 +213,12 @@ def update_parameters(self, conn):
data = receive(conn)
delta = data['delta']
weights = self.master_network.get_weights()
if self.mode == 'asynchronous':
self.lock.acquire_write()
# apply the gradient
self.master_network.set_weights(subtract_params(weights, delta))
if self.mode == 'asynchronous':
self.lock.release()

def get_parameters(self, conn):
if self.mode == 'asynchronous':
self.lock.acquire_read()
weights = self.master_network.get_weights()
send(conn, weights)
if self.mode == 'asynchronous':
self.lock.release()

def action_listener(self, conn):
while self.runs:
Expand Down

0 comments on commit d27e695

Please sign in to comment.