diff --git a/cli/src/cli.ts b/cli/src/cli.ts index d4b586e17..f22d600b9 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -1,9 +1,11 @@ import { Range } from 'immutable' +import type { Server } from 'node:http' import type { TrainerLog, data, Task } from '@epfml/discojs-core' -import { Disco, TrainingSchemes } from '@epfml/discojs-core' +import { Disco, TrainingSchemes, aggregator as aggregators, client as clients } from '@epfml/discojs-core' +import { getClient, startServer } from '@epfml/disco-server' -import { startServer, saveLog } from './utils' +import { saveLog } from './utils' import { getTaskData } from './data' import { args } from './args' @@ -15,10 +17,12 @@ console.log(infoText) console.log({ args }) -async function runUser (task: Task, url: URL, data: data.DataSplit): Promise { +async function runUser (task: Task, server: Server, data: data.DataSplit): Promise { + const client = await getClient(clients.federated.FederatedClient, server, task, new aggregators.MeanAggregator(TASK)) + // force the federated scheme const scheme = TrainingSchemes.FEDERATED - const disco = new Disco(task, { scheme, url }) + const disco = new Disco(task, { scheme, client }) await disco.fit(data) await disco.close() @@ -26,12 +30,12 @@ async function runUser (task: Task, url: URL, data: data.DataSplit): Promise { - const [server, serverUrl] = await startServer() + const server = await startServer() const data = await getTaskData(TASK) const logs = await Promise.all( - Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, serverUrl, data)).toArray() + Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, server, data)).toArray() ) if (args.save) { diff --git a/cli/src/utils.ts b/cli/src/utils.ts index 3c1f0dc9c..47ef2458d 100644 --- a/cli/src/utils.ts +++ b/cli/src/utils.ts @@ -1,38 +1,6 @@ -import type http from 'node:http' import fs from 'node:fs' import type { TrainerLog } from '@epfml/discojs-core' -import { Disco } from '@epfml/disco-server' - -export async function startServer (): Promise<[http.Server, URL]> { - const disco = new Disco() - await disco.addDefaultTasks() - - const server = disco.serve(8000) - await new Promise((resolve, reject) => { - server.once('listening', resolve) - server.once('error', reject) - server.on('error', console.error) - }) - - let addr: string - const rawAddr = server.address() - if (rawAddr === null) { - throw new Error('unable to get server address') - } else if (typeof rawAddr === 'string') { - addr = rawAddr - } else if (typeof rawAddr === 'object') { - if (rawAddr.family === '4') { - addr = `${rawAddr.address}:${rawAddr.port}` - } else { - addr = `[${rawAddr.address}]:${rawAddr.port}` - } - } else { - throw new Error('unable to get address to server') - } - - return [server, new URL('', `http://${addr}`)] -} export function saveLog (logs: TrainerLog[], fileName: string): void { const filePath = `./${fileName}` diff --git a/server/src/index.ts b/server/src/index.ts index b30c01adc..0dff1d1b9 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -1 +1,2 @@ export { Disco } from './get_server' +export * from './utils' diff --git a/server/tests/utils.ts b/server/src/utils.ts similarity index 95% rename from server/tests/utils.ts rename to server/src/utils.ts index df9ba0185..21e88ab25 100644 --- a/server/tests/utils.ts +++ b/server/src/utils.ts @@ -2,7 +2,7 @@ import type { Server } from 'node:http' import type { aggregator, client, Task } from '@epfml/discojs-core' -import { runDefaultServer } from '../src/get_server' +import { runDefaultServer } from './get_server' export async function startServer (): Promise { const server = await runDefaultServer() diff --git a/server/tests/client/decentralized.spec.ts b/server/tests/client/decentralized.spec.ts index 48bf560d0..490480f1d 100644 --- a/server/tests/client/decentralized.spec.ts +++ b/server/tests/client/decentralized.spec.ts @@ -3,7 +3,7 @@ import type * as http from 'http' import type { Task } from '@epfml/discojs-core' import { aggregator as aggregators, client as clients, defaultTasks } from '@epfml/discojs-core' -import { getClient, startServer } from '../utils' +import { getClient, startServer } from '../../src' const TASK = defaultTasks.titanic.getTask() diff --git a/server/tests/client/federated.spec.ts b/server/tests/client/federated.spec.ts index 207d87e51..ae50250aa 100644 --- a/server/tests/client/federated.spec.ts +++ b/server/tests/client/federated.spec.ts @@ -2,7 +2,7 @@ import type * as http from 'http' import { aggregator as aggregators, client as clients, informant, defaultTasks } from '@epfml/discojs-core' -import { getClient, startServer } from '../utils' +import { getClient, startServer } from '../../src' const TASK = defaultTasks.titanic.getTask() diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 67ad64a40..57336f4ce 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -7,7 +7,7 @@ import { aggregator as aggregators, informant as informants, client as clients, WeightsContainer, defaultTasks, aggregation } from '@epfml/discojs-core' -import { getClient, startServer } from '../utils' +import { getClient, startServer } from '../../src' // Mocked aggregators with easy-to-fetch aggregation results class MockMeanAggregator extends aggregators.MeanAggregator { diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index e81360690..78bcd65fa 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -11,7 +11,7 @@ import { } from '@epfml/discojs-core' import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node' -import { getClient, startServer } from '../utils' +import { getClient, startServer } from '../../src' const SCHEME = TrainingSchemes.FEDERATED