Skip to content

Commit

Permalink
(fix): expression measurement is migrated and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
fern-bot committed Apr 23, 2024
1 parent 9e1883c commit 4a13cc1
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 99 deletions.
4 changes: 1 addition & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
export * as Hume from "./api";
export { HumeBatchClient } from "./wrapper/HumeBatchClient";
export { HumeStreamingClient } from "./wrapper/HumeStreamingClient";
export { Job } from "./wrapper/Job";
export { HumeClient } from "./wrapper/HumeClient";
export { HumeEnvironment } from "./environments";
export { HumeError, HumeTimeoutError } from "./errors";
10 changes: 0 additions & 10 deletions src/wrapper/HumeBatchClient.ts

This file was deleted.

10 changes: 10 additions & 0 deletions src/wrapper/HumeClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { HumeClient as FernClient } from "../Client";
import { ExpressionMeasurement } from "./expressionMeasurement/ExpressionMeasurementClient";

export class HumeClient extends FernClient {
protected _expressionMeasurement: ExpressionMeasurement | undefined;

public get expressionMeasurement(): ExpressionMeasurement {
return (this._expressionMeasurement ??= new ExpressionMeasurement(this._options));
}
}
17 changes: 17 additions & 0 deletions src/wrapper/expressionMeasurement/ExpressionMeasurementClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { ExpressionMeasurement as FernClient } from "../../api/resources/expressionMeasurement/client/Client";
import { BatchClient } from "./batch/BatchClient";
import { StreamClient } from "./streaming/StreamingClient";

export class ExpressionMeasurement extends FernClient {
protected _batch: BatchClient | undefined;

public get batch(): BatchClient {
return (this._batch ??= new BatchClient(this._options));
}

protected _stream: StreamClient | undefined;

public get stream(): StreamClient {
return (this._stream ??= new StreamClient(this._options));
}
}
13 changes: 13 additions & 0 deletions src/wrapper/expressionMeasurement/batch/BatchClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { Batch as FernClient } from "../../../api/resources/expressionMeasurement/resources/batch/client/Client";
import * as Hume from "../../../api";
import { Job } from "./Job";

export class BatchClient extends FernClient {
public async startInferenceJob(
request: Hume.expressionMeasurement.InferenceBaseRequest = {},
requestOptions?: FernClient.RequestOptions
): Promise<Job> {
const { jobId } = await super.startInferenceJob(request, requestOptions);
return new Job(jobId, this);
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import { createWriteStream } from "fs";
import { writeFile } from "fs/promises";
import { pipeline } from "stream/promises";
import * as Hume from "../api";
import * as errors from "../errors";
import { HumeBatchClient } from "./HumeBatchClient";
import * as Hume from "../../../api";
import * as errors from "../../../errors";
import { BatchClient } from "./BatchClient";

export class Job implements Hume.JobId {
constructor(
public readonly jobId: string,
private readonly client: HumeBatchClient
) {}
export class Job implements Hume.expressionMeasurement.JobId {
constructor(public readonly jobId: string, private readonly client: BatchClient) {}

public async awaitCompletion(timeoutInSeconds = 300): Promise<void> {
return new Promise((resolve, reject) => {
Expand All @@ -25,10 +19,7 @@ export class Job implements Hume.JobId {

class JobCompletionPoller {
private isPolling = true;
constructor(
private readonly jobId: string,
private readonly client: HumeBatchClient
) {}
constructor(private readonly jobId: string, private readonly client: BatchClient) {}

public start(onTerminal: () => void) {
this.isPolling = true;
Expand All @@ -42,10 +33,7 @@ class JobCompletionPoller {
private async poll(onTerminal: () => void): Promise<void> {
try {
const jobDetails = await this.client.getJobDetails(this.jobId);
if (
jobDetails.state.status === "COMPLETED" ||
jobDetails.state.status === "FAILED"
) {
if (jobDetails.state.status === "COMPLETED" || jobDetails.state.status === "FAILED") {
onTerminal();
this.stop();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import WebSocket from "ws";
import { v4 as uuid } from "uuid";
import { parse } from "./HumeStreamingClient";
import { parse } from "./StreamingClient";
import { base64Encode } from "./base64Encode";
import * as Hume from "../api";
import * as serializers from "../serialization";
import * as errors from "../errors";
import * as Hume from "../../../api";
import * as errors from "../../../errors";
import * as serializers from "../../../serialization";
import * as fs from "fs";

export declare namespace StreamSocket {
interface Args {
websocket: WebSocket;
config: Hume.ModelConfig;
config: Hume.expressionMeasurement.StreamDataModels;
streamWindowMs?: number;
}
}

export class StreamSocket {
readonly websocket: WebSocket;
private readonly streamWindowMs?: number;
private config: Hume.ModelConfig;
private config: Hume.expressionMeasurement.StreamDataModels;

constructor({ websocket, config, streamWindowMs }: StreamSocket.Args) {
this.websocket = websocket;
Expand All @@ -38,8 +38,8 @@ export class StreamSocket {
config,
}: {
file: fs.ReadStream | Blob;
config?: Hume.ModelConfig;
}): Promise<Hume.ModelResponse> {
config?: Hume.expressionMeasurement.StreamDataModels;
}): Promise<Hume.expressionMeasurement.StreamBurst | Hume.expressionMeasurement.StreamError> {
if (config != null) {
this.config = config;
}
Expand All @@ -61,7 +61,7 @@ export class StreamSocket {
} else {
throw new errors.HumeError({ message: `file must be one of ReadStream or Blob.` });
}
const request: Hume.ModelsInput = {
const request: Hume.expressionMeasurement.stream.StreamData = {
payloadId: uuid(),
data: contents,
models: this.config,
Expand All @@ -84,11 +84,17 @@ export class StreamSocket {
* @param config This method is intended for use with a `LanguageConfig`.
* When the socket is configured for other modalities this method will fail.
*/
public async sendText({ text, config }: { text: string; config?: Hume.ModelConfig }): Promise<Hume.ModelResponse> {
public async sendText({
text,
config,
}: {
text: string;
config?: Hume.expressionMeasurement.StreamDataModels;
}): Promise<Hume.expressionMeasurement.StreamBurst | Hume.expressionMeasurement.StreamError> {
if (config != null) {
this.config = config;
}
const request: Hume.ModelsInput = {
const request: Hume.expressionMeasurement.StreamData = {
payloadId: uuid(),
data: text,
rawText: true,
Expand Down Expand Up @@ -119,8 +125,8 @@ export class StreamSocket {
config,
}: {
landmarks: number[][][];
config?: Hume.ModelConfig;
}): Promise<Hume.ModelResponse> {
config?: Hume.expressionMeasurement.StreamDataModels;
}): Promise<Hume.expressionMeasurement.StreamBurst | Hume.expressionMeasurement.StreamError> {
const response = this.sendText({
text: base64Encode(JSON.stringify(landmarks)),
config,
Expand Down Expand Up @@ -159,26 +165,25 @@ export class StreamSocket {
this.websocket.close();
}

private async send(payload: Hume.ModelsInput): Promise<Hume.ModelResponse | void> {
private async send(
payload: Hume.expressionMeasurement.StreamData
): Promise<Hume.expressionMeasurement.SubscribeEvent | void> {
await this.tillSocketOpen();
const jsonPayload = await serializers.ModelsInput.jsonOrThrow(payload, {
const jsonPayload = await serializers.expressionMeasurement.StreamData.jsonOrThrow(payload, {
unrecognizedObjectKeys: "strip",
});
this.websocket.send(JSON.stringify(jsonPayload));
const response = await new Promise<Hume.ModelResponse | Hume.ModelsWarning | Hume.ModelsError | undefined>(
(resolve, reject) => {
this.websocket.addEventListener("message", (event) => {
const response = parse(event.data);
resolve(response);
});
}
);
const response = await new Promise<
Hume.expressionMeasurement.StreamBurst | Hume.expressionMeasurement.StreamError | undefined
>((resolve, reject) => {
this.websocket.addEventListener("message", (event) => {
const response = parse(event.data);
resolve(response);
});
});
if (response != null && isError(response)) {
throw new errors.HumeError({ message: `CODE ${response.code}: ${response.error}` });
}
if (response != null && isWarning(response)) {
throw new errors.HumeError({ message: `CODE ${response.code}: ${response.warning}` });
}
return response;
}

Expand All @@ -198,12 +203,8 @@ export class StreamSocket {
}
}

function isError(response: Hume.ModelResponse | Hume.ModelsWarning | Hume.ModelsError): response is Hume.ModelsError {
return (response as Hume.ModelsError).error != null;
}

function isWarning(
response: Hume.ModelResponse | Hume.ModelsWarning | Hume.ModelsError
): response is Hume.ModelsWarning {
return (response as Hume.ModelsWarning).warning != null;
function isError(
response: Hume.expressionMeasurement.StreamBurst | Hume.expressionMeasurement.StreamError
): response is Hume.expressionMeasurement.StreamError {
return (response as Hume.expressionMeasurement.StreamError).error != null;
}
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
import * as Hume from "../api";
import * as serializers from "../serialization";
import * as Hume from "../../../api";
import * as serializers from "../../../serialization";
import * as core from "../../../core";
import { StreamSocket } from "./StreamSocket";
import WebSocket from "ws";

export declare namespace HumeStreamingClient {
export declare namespace StreamClient {
interface Options {
apiKey: string;
/* Defaults to 10 seconds */
openTimeoutInSeconds?: number;
apiKey?: core.Supplier<string | undefined>;
}

interface ConnectArgs {
/* Job config */
config: Hume.ModelConfig;
config: Hume.expressionMeasurement.StreamDataModels;
/* Length of the sliding window in milliseconds to use when
aggregating media across streaming payloads within one WebSocket connection. */
streamWindowMs?: number;

onOpen?: (event: WebSocket.Event) => void;
onMessage?: (message: Hume.ModelResponse) => void;
onWarning?: (error: Hume.ModelsWarning) => void;
onError?: (error: Hume.ModelsError) => void;
onMessage?: (message: Hume.expressionMeasurement.stream.StreamBurst) => void;
onError?: (error: Hume.expressionMeasurement.stream.StreamError) => void;
onClose?: (event: WebSocket.Event) => void;
}
}

export class HumeStreamingClient {
constructor(protected readonly _options: HumeStreamingClient.Options) {}
export class StreamClient {
constructor(protected readonly _options: StreamClient.Options) {}

public connect(args: HumeStreamingClient.ConnectArgs): StreamSocket {
public connect(args: StreamClient.ConnectArgs): StreamSocket {
const websocket = new WebSocket(`wss://api.hume.ai/v0/stream/models`, {
headers: {
"X-Hume-Api-Key": this._options.apiKey,
"X-Hume-Api-Key": typeof this._options.apiKey === "string" ? this._options.apiKey : "",
},
timeout: this._options.openTimeoutInSeconds,
timeout: 10
});

websocket.addEventListener("open", (event) => {
Expand All @@ -50,7 +48,6 @@ export class HumeStreamingClient {
websocket.addEventListener("message", async ({ data }) => {
parse(data, {
onMessage: args.onMessage,
onWarning: args.onWarning,
onError: args.onError,
});
});
Expand All @@ -70,14 +67,13 @@ export class HumeStreamingClient {
export async function parse(
data: WebSocket.Data,
args: {
onMessage?: (message: Hume.ModelResponse) => void;
onWarning?: (error: Hume.ModelsWarning) => void;
onError?: (error: Hume.ModelsError) => void;
onMessage?: (message: Hume.expressionMeasurement.stream.StreamBurst) => void;
onError?: (error: Hume.expressionMeasurement.stream.StreamError) => void;
} = {}
): Promise<Hume.ModelResponse | Hume.ModelsWarning | Hume.ModelsError | undefined> {
): Promise<Hume.expressionMeasurement.stream.StreamBurst | Hume.expressionMeasurement.stream.StreamError | undefined> {
const message = JSON.parse(data as string);

const parsedResponse = await serializers.ModelResponse.parse(message, {
const parsedResponse = await serializers.expressionMeasurement.stream.StreamBurst.parse(message, {
unrecognizedObjectKeys: "passthrough",
allowUnrecognizedUnionMembers: true,
allowUnrecognizedEnumValues: true,
Expand All @@ -88,18 +84,7 @@ export async function parse(
return parsedResponse.value;
}

const parsedWarning = await serializers.ModelsWarning.parse(message, {
unrecognizedObjectKeys: "passthrough",
allowUnrecognizedUnionMembers: true,
allowUnrecognizedEnumValues: true,
breadcrumbsPrefix: ["response"],
});
if (parsedWarning.ok) {
args.onWarning?.(parsedWarning.value);
return parsedWarning.value;
}

const parsedError = await serializers.ModelsError.parse(message, {
const parsedError = await serializers.expressionMeasurement.stream.StreamError.parse(message, {
unrecognizedObjectKeys: "passthrough",
allowUnrecognizedUnionMembers: true,
allowUnrecognizedEnumValues: true,
Expand Down
File renamed without changes.
18 changes: 18 additions & 0 deletions tests/expressionMeasurement/batch.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { HumeClient } from "../../src/"

describe("Streaming Expression Measurement", () => {
it("Emotional Language Text", async () => {
const hume = new HumeClient({
apiKey: "<>"
});
const job = await hume.expressionMeasurement.batch.startInferenceJob({
models: {
face: {}
},
urls: ["https://hume-tutorials.s3.amazonaws.com/faces.zip"]
});
await job.awaitCompletion();
const predictions = await hume.expressionMeasurement.batch.getJobPredictions(job.jobId);
console.log(JSON.stringify(predictions, null, 2));
});
});
25 changes: 25 additions & 0 deletions tests/expressionMeasurement/streaming.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import { HumeClient } from "../../src/"

const samples = [
"Mary had a little lamb,",
"Its fleece was white as snow.",
"Everywhere the child went,",
"The little lamb was sure to go."
];

describe("Streaming Expression Measurement", () => {
it.skip("Emotional Language Text", async () => {
const hume = new HumeClient({
apiKey: "<>"
});
const socket = hume.expressionMeasurement.stream.connect({
config: {
language: {}
}
})
for (const sample of samples) {
const result = await socket.sendText({ text: sample })
console.log(result)
}
}, 100000);
});

0 comments on commit 4a13cc1

Please sign in to comment.