Skip to content

Commit

Permalink
chore: remove models and model from input and output identifier
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com>
  • Loading branch information
ThibaultFy committed Jun 26, 2023
1 parent 485d1e2 commit f290240
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
13 changes: 5 additions & 8 deletions tests/workflows/mnist-fedavg/assets/aggregate_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import torch
import torch.nn.functional as F

from substratest.fl_interface import InputIdentifiers
from substratest.fl_interface import OutputIdentifiers

_INPUT_SAMPLE_SIZE = 21632
_OUT_SAMPLE_SIZE = 10
_NB_CHANNELS = 32
Expand Down Expand Up @@ -37,7 +34,7 @@ def forward(self, x):
def aggregate(inputs, outputs, task_properties):
# get layers
inmodels = []
for m_path in inputs[InputIdentifiers.shared]:
for m_path in inputs["models"]:
inmodels.append(load_model(m_path))

model = inmodels[0]
Expand All @@ -53,15 +50,15 @@ def aggregate(inputs, outputs, task_properties):

model.load_state_dict(model_state_dict)

save_model(model, outputs[InputIdentifiers.shared])
save_model(model, outputs["shared"])


@tools.register
def predict(inputs, outputs, task_properties):
X = inputs[InputIdentifiers.datasamples]["X"]
X = inputs["datasamples"]["X"]
X = torch.FloatTensor(X)

model = load_model(inputs[InputIdentifiers.shared])
model = load_model(inputs["shared"])
model.eval()
# add the context manager to reduce computation overhead
with torch.no_grad():
Expand All @@ -70,7 +67,7 @@ def predict(inputs, outputs, task_properties):
y_pred = y_pred.data.cpu().numpy()
pred = np.argmax(y_pred, axis=1)

save_predictions(pred, outputs[OutputIdentifiers.predictions])
save_predictions(pred, outputs["predictions"])


def load_model(path):
Expand Down
21 changes: 9 additions & 12 deletions tests/workflows/mnist-fedavg/assets/composite_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import torch
import torch.nn.functional as F

from substratest.fl_interface import InputIdentifiers
from substratest.fl_interface import OutputIdentifiers

_INPUT_SAMPLE_SIZE = 21632
_OUT_SAMPLE_SIZE = 10
_NB_CHANNELS = 32
Expand Down Expand Up @@ -67,14 +64,14 @@ def train(inputs, outputs, task_properties):
torch.manual_seed(_SEED) # initialize model weights
torch.use_deterministic_algorithms(True)

head_model_path = inputs.get(InputIdentifiers.local)
trunk_model_path = inputs.get(InputIdentifiers.shared)
head_model_path = inputs.get("local")
trunk_model_path = inputs.get("shared")

head_model = load_head_model(head_model_path) if head_model_path is not None else torch.nn.Module()
trunk_model = load_trunk_model(trunk_model_path) if trunk_model_path is not None else Network()

X = inputs[InputIdentifiers.datasamples]["X"]
y = inputs[InputIdentifiers.datasamples]["y"]
X = inputs["datasamples"]["X"]
y = inputs["datasamples"]["y"]
rank = task_properties["rank"]

_fit(
Expand All @@ -86,15 +83,15 @@ def train(inputs, outputs, task_properties):
rank=rank,
)

save_head_model(head_model, outputs[OutputIdentifiers.local])
save_trunk_model(trunk_model, outputs[OutputIdentifiers.shared])
save_head_model(head_model, outputs["local"])
save_trunk_model(trunk_model, outputs["shared"])


@tools.register
def predict(inputs, outputs, task_properties):
trunk_model = load_trunk_model(inputs[OutputIdentifiers.shared])
trunk_model = load_trunk_model(inputs["shared"])

X = inputs[InputIdentifiers.datasamples]["X"]
X = inputs["datasamples"]["X"]
X = torch.FloatTensor(X)
trunk_model.eval()

Expand All @@ -105,7 +102,7 @@ def predict(inputs, outputs, task_properties):
y_pred = y_pred.data.cpu().numpy()
pred = np.argmax(y_pred, axis=1)

save_predictions(pred, outputs[OutputIdentifiers.predictions])
save_predictions(pred, outputs["predictions"])


def load_model(path):
Expand Down

0 comments on commit f290240

Please sign in to comment.