Skip to content

Commit

Permalink
server: export starter
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 1, 2024
1 parent fd2e016 commit 8eede14
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 43 deletions.
16 changes: 10 additions & 6 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { Range } from 'immutable'
import type { Server } from 'node:http'

import type { TrainerLog, data, Task } from '@epfml/discojs-core'
import { Disco, TrainingSchemes } from '@epfml/discojs-core'
import { Disco, TrainingSchemes, aggregator as aggregators, client as clients } from '@epfml/discojs-core'
import { getClient, startServer } from '@epfml/disco-server'

import { startServer, saveLog } from './utils'
import { saveLog } from './utils'
import { getTaskData } from './data'
import { args } from './args'

Expand All @@ -15,23 +17,25 @@ console.log(infoText)

console.log({ args })

async function runUser (task: Task, url: URL, data: data.DataSplit): Promise<TrainerLog> {
async function runUser (task: Task, server: Server, data: data.DataSplit): Promise<TrainerLog> {
const client = await getClient(clients.federated.FederatedClient, server, task, new aggregators.MeanAggregator(TASK))

// force the federated scheme
const scheme = TrainingSchemes.FEDERATED
const disco = new Disco(task, { scheme, url })
const disco = new Disco(task, { scheme, client })

await disco.fit(data)
await disco.close()
return await disco.logs()
}

async function main (): Promise<void> {
const [server, serverUrl] = await startServer()
const server = await startServer()

const data = await getTaskData(TASK)

const logs = await Promise.all(
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, serverUrl, data)).toArray()
Range(0, NUMBER_OF_USERS).map(async (_) => await runUser(TASK, server, data)).toArray()
)

if (args.save) {
Expand Down
32 changes: 0 additions & 32 deletions cli/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,6 @@
import type http from 'node:http'
import fs from 'node:fs'

import type { TrainerLog } from '@epfml/discojs-core'
import { Disco } from '@epfml/disco-server'

export async function startServer (): Promise<[http.Server, URL]> {
const disco = new Disco()
await disco.addDefaultTasks()

const server = disco.serve(8000)
await new Promise((resolve, reject) => {
server.once('listening', resolve)
server.once('error', reject)
server.on('error', console.error)
})

let addr: string
const rawAddr = server.address()
if (rawAddr === null) {
throw new Error('unable to get server address')
} else if (typeof rawAddr === 'string') {
addr = rawAddr
} else if (typeof rawAddr === 'object') {
if (rawAddr.family === '4') {
addr = `${rawAddr.address}:${rawAddr.port}`
} else {
addr = `[${rawAddr.address}]:${rawAddr.port}`
}
} else {
throw new Error('unable to get address to server')
}

return [server, new URL('', `http://${addr}`)]
}

export function saveLog (logs: TrainerLog[], fileName: string): void {
const filePath = `./${fileName}`
Expand Down
1 change: 1 addition & 0 deletions server/src/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export { Disco } from './get_server'
export * from './utils'
2 changes: 1 addition & 1 deletion server/tests/utils.ts → server/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { Server } from 'node:http'

import type { aggregator, client, Task } from '@epfml/discojs-core'

import { runDefaultServer } from '../src/get_server'
import { runDefaultServer } from './get_server'

export async function startServer (): Promise<Server> {
const server = await runDefaultServer()
Expand Down
2 changes: 1 addition & 1 deletion server/tests/client/decentralized.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type * as http from 'http'
import type { Task } from '@epfml/discojs-core'
import { aggregator as aggregators, client as clients, defaultTasks } from '@epfml/discojs-core'

import { getClient, startServer } from '../utils'
import { getClient, startServer } from '../../src'

const TASK = defaultTasks.titanic.getTask()

Expand Down
2 changes: 1 addition & 1 deletion server/tests/client/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type * as http from 'http'

import { aggregator as aggregators, client as clients, informant, defaultTasks } from '@epfml/discojs-core'

import { getClient, startServer } from '../utils'
import { getClient, startServer } from '../../src'

const TASK = defaultTasks.titanic.getTask()

Expand Down
2 changes: 1 addition & 1 deletion server/tests/e2e/decentralized.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
aggregator as aggregators, informant as informants, client as clients, WeightsContainer, defaultTasks, aggregation
} from '@epfml/discojs-core'

import { getClient, startServer } from '../utils'
import { getClient, startServer } from '../../src'

// Mocked aggregators with easy-to-fetch aggregation results
class MockMeanAggregator extends aggregators.MeanAggregator {
Expand Down
2 changes: 1 addition & 1 deletion server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from '@epfml/discojs-core'
import { NodeImageLoader, NodeTabularLoader } from '@epfml/discojs-node'

import { getClient, startServer } from '../utils'
import { getClient, startServer } from '../../src'

const SCHEME = TrainingSchemes.FEDERATED

Expand Down

0 comments on commit 8eede14

Please sign in to comment.