Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added wikitext task, added text-dataset for core, node and web, renamed task.taskId to task.id #619

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 53 additions & 50 deletions discojs/discojs-core/src/async_informant.ts
Original file line number Diff line number Diff line change
@@ -1,64 +1,67 @@
import { AggregatorBase } from './aggregator'

export class AsyncInformant<T> {
private _round = 0
private _currentNumberOfParticipants = 0
private _totalNumberOfParticipants = 0
private _averageNumberOfParticipants = 0
private _round = 0

Check failure on line 4 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _currentNumberOfParticipants = 0

Check failure on line 5 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _totalNumberOfParticipants = 0

Check failure on line 6 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _averageNumberOfParticipants = 0

Check failure on line 7 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

constructor (
private readonly aggregator: AggregatorBase<T>
) {}
constructor(private readonly aggregator: AggregatorBase<T>) {}

Check failure on line 9 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

Check failure on line 9 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Missing space before function parentheses

update (): void {
console.debug('before:')
this.printAllInfos()
if (this.round === 0 || this.round < this.aggregator.round) {
this._round = this.aggregator.round
this._currentNumberOfParticipants = this.aggregator.size
this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round
this._totalNumberOfParticipants += this.currentNumberOfParticipants
} else {
this._round = this.aggregator.round
update(): void {

Check failure on line 11 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

Check failure on line 11 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Missing space before function parentheses
console.debug('before:')

Check failure on line 12 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 4 spaces but found 8
this.printAllInfos()

Check failure on line 13 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 4 spaces but found 8
if (this.round === 0 || this.round < this.aggregator.round) {
this._round = this.aggregator.round
this._currentNumberOfParticipants = this.aggregator.size
this._averageNumberOfParticipants =
this.totalNumberOfParticipants / this.round
this._totalNumberOfParticipants += this.currentNumberOfParticipants
} else {
this._round = this.aggregator.round
}
console.debug('after:')
this.printAllInfos()
}
console.debug('after:')
this.printAllInfos()
}

// Getter functions
get round (): number {
return this._round
}
// Getter functions
get round(): number {
return this._round
}

get currentNumberOfParticipants (): number {
return this._currentNumberOfParticipants
}
get currentNumberOfParticipants(): number {
return this._currentNumberOfParticipants
}

get totalNumberOfParticipants (): number {
return this._totalNumberOfParticipants
}
get totalNumberOfParticipants(): number {
return this._totalNumberOfParticipants
}

get averageNumberOfParticipants (): number {
return this._averageNumberOfParticipants
}
get averageNumberOfParticipants(): number {
return this._averageNumberOfParticipants
}

getAllStatistics (): Record<
'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number
> {
return {
round: this.round,
currentNumberOfParticipants: this.currentNumberOfParticipants,
totalNumberOfParticipants: this.totalNumberOfParticipants,
averageNumberOfParticipants: this.averageNumberOfParticipants
getAllStatistics(): Record<
| 'round'
| 'currentNumberOfParticipants'
| 'totalNumberOfParticipants'
| 'averageNumberOfParticipants',
number
> {
return {
round: this.round,
currentNumberOfParticipants: this.currentNumberOfParticipants,
totalNumberOfParticipants: this.totalNumberOfParticipants,
averageNumberOfParticipants: this.averageNumberOfParticipants,
}
}
}

// Debug
public printAllInfos (): void {
console.debug('task:', this.aggregator.task.taskID)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
console.debug('average:', this.averageNumberOfParticipants)
}
// Debug
public printAllInfos(): void {
console.debug('task:', this.aggregator.task.id)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
console.debug('average:', this.averageNumberOfParticipants)
}
}
198 changes: 102 additions & 96 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import { Set } from 'immutable'
import axios from 'axios'

import { tf, Task, TrainingInformant, serialization, WeightsContainer } from '..'
import {
tf,
Task,
TrainingInformant,
serialization,
WeightsContainer,
} from '..'
import { NodeID } from './types'
import { EventConnection } from './event_connection'
import { Aggregator } from '../aggregator'
Expand All @@ -11,119 +17,119 @@ import { Aggregator } from '../aggregator'
* communication with other nodes, be it peers or a server.
*/
export abstract class Base {
/**
* Own ID provided by the network's server.
*/
protected _ownId?: NodeID
/**
* The network's server.
*/
protected _server?: EventConnection
/**
* The aggregator's result produced after aggregation.
*/
protected aggregationResult?: Promise<WeightsContainer>

constructor (
/**
* The network server's URL to connect to.
* Own ID provided by the network's server.
*/
public readonly url: URL,
protected _ownId?: NodeID
/**
* The client's corresponding task.
* The network's server.
*/
public readonly task: Task,
protected _server?: EventConnection
/**
* The client's aggregator.
* The aggregator's result produced after aggregation.
*/
public readonly aggregator: Aggregator
) {}
protected aggregationResult?: Promise<WeightsContainer>

/**
* Handles the connection process from the client to any sort of network server.
*/
async connect (): Promise<void> {}
constructor(
/**
* The network server's URL to connect to.
*/
public readonly url: URL,
/**
* The client's corresponding task.
*/
public readonly task: Task,
/**
* The client's aggregator.
*/
public readonly aggregator: Aggregator
) {}

/**
* Handles the disconnection process of the client from any sort of network server.
*/
async disconnect (): Promise<void> {}
/**
* Handles the connection process from the client to any sort of network server.
*/
async connect(): Promise<void> {}

/**
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<tf.LayersModel> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
}
url.pathname += `tasks/${this.task.taskID}/model.json`
/**
* Handles the disconnection process of the client from any sort of network server.
*/
async disconnect(): Promise<void> {}

const response = await axios.get(url.href)
/**
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel(): Promise<tf.LayersModel> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
}
url.pathname += `tasks/${this.task.id}/model.json`

return await serialization.model.decode(response.data)
}
const response = await axios.get(url.href)

/**
* Communication callback called once at the beginning of the training instance.
* @param weights The initial model weights
* @param trainingInformant The training informant
*/
async onTrainBeginCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}
return await serialization.model.decode(response.data)
}

/**
* Communication callback called once at the end of the training instance.
* @param weights The final model weights
* @param trainingInformant The training informant
*/
async onTrainEndCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called once at the beginning of the training instance.
* @param weights The initial model weights
* @param trainingInformant The training informant
*/
async onTrainBeginCommunication(
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called at the beginning of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundBeginCommunication (
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called once at the end of the training instance.
* @param weights The final model weights
* @param trainingInformant The training informant
*/
async onTrainEndCommunication(
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called the end of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundEndCommunication (
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called at the beginning of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundBeginCommunication(
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}

get nodes (): Set<NodeID> {
return this.aggregator.nodes
}
/**
* Communication callback called the end of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundEndCommunication(
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}

get nodes(): Set<NodeID> {
return this.aggregator.nodes
}

get ownId (): NodeID {
if (this._ownId === undefined) {
throw new Error('the node is not connected')
get ownId(): NodeID {
if (this._ownId === undefined) {
throw new Error('the node is not connected')
}
return this._ownId
}
return this._ownId
}

get server (): EventConnection {
if (this._server === undefined) {
throw new Error('server undefined, not connected')
get server(): EventConnection {
if (this._server === undefined) {
throw new Error('server undefined, not connected')
}
return this._server
}
return this._server
}
}
Loading
Loading