
Commit

gRPC client stubs
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
mudler committed Dec 4, 2024
1 parent baf7bb6 commit 83c6a4c
Showing 4 changed files with 41 additions and 6 deletions.
2 changes: 1 addition & 1 deletion backend/backend.proto
@@ -159,7 +159,7 @@ message Reply {
   bytes message = 1;
   int32 tokens = 2;
   int32 prompt_tokens = 3;
-  string audio_output = 4;
+  bytes audio = 5;
 }

 message ModelOptions {
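A note on consuming the changed field (editor's sketch, not part of the commit): once the Go bindings are regenerated, protoc-gen-go emits a GetAudio() accessor for the new `bytes audio = 5` field. A minimal caller might look like the following; the import path, the predictor interface, and saveReplyAudio are assumptions, not code from this repository.

package main

import (
	"context"
	"os"

	"github.com/mudler/LocalAI/pkg/grpc/proto" // import path assumed
	"google.golang.org/grpc"
)

// predictor is a local stand-in for any client exposing Predict, such as
// the Model interface added in realtime.go below.
type predictor interface {
	Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
}

// saveReplyAudio runs a prediction and writes the raw bytes from the new
// Reply audio field to disk; the encoding (e.g. WAV) depends on the backend.
func saveReplyAudio(ctx context.Context, p predictor, prompt string) error {
	reply, err := p.Predict(ctx, &proto.PredictOptions{Prompt: prompt})
	if err != nil {
		return err
	}
	// GetAudio is the accessor protoc-gen-go generates for `bytes audio = 5`.
	return os.WriteFile("out.audio", reply.GetAudio(), 0o644)
}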
17 changes: 13 additions & 4 deletions core/http/endpoints/openai/realtime.go
@@ -120,6 +120,8 @@ var sessionLock sync.Mutex
 // TODO: implement interface as we start to define usages
 type Model interface {
 	VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
+	Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
+	PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
 }

 func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
@@ -800,7 +802,17 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *
 	// 4. Convert the response text to speech (audio)
 	//
 	// Placeholder implementation:
-	// TODO: use session.ModelInterface...
+
+	// TODO: template eventual messages, like chat.go
+	reply, err := session.ModelInterface.Predict(context.Background(), &proto.PredictOptions{
+		Prompt: "What's the weather in New York?",
+	})
+
+	if err != nil {
+		return "", nil, nil, err
+	}
+
+	generatedAudio := reply.Audio

 	transcribedText := "What's the weather in New York?"
 	var functionCall *FunctionCall
@@ -819,9 +831,6 @@ func processAudioResponse(session *Session, audioData []byte) (string, []byte, *

 	// Generate a response
 	generatedText := "This is a response to your speech input."
-	generatedAudio := []byte{} // Generate audio bytes from the generatedText
-
-	// TODO: Implement actual transcription and TTS

 	return generatedText, generatedAudio, nil, nil
 }
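For orientation (editor's sketch, not in the commit): the TODOs above point at templating the prompt from session state and driving the whole exchange through the new interface. Under those assumptions, a later version of this placeholder might look roughly like the following; transcribe and buildPrompt are hypothetical helpers that do not exist in the repository.

// Hypothetical shape of the finished placeholder: STT on the incoming
// audio, a templated prompt (as chat.go does for text), then one Predict
// call whose Reply carries both text and audio.
func respondToAudio(ctx context.Context, session *Session, audioData []byte) (string, []byte, error) {
	transcript := transcribe(audioData)        // assumed STT step, still a TODO in the commit
	prompt := buildPrompt(session, transcript) // assumed templating step, "like chat.go"

	reply, err := session.ModelInterface.Predict(ctx, &proto.PredictOptions{Prompt: prompt})
	if err != nil {
		return "", nil, err
	}
	// Reply.message is bytes; Reply.audio is the new bytes field from backend.proto.
	return string(reply.GetMessage()), reply.GetAudio(), nil
}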
26 changes: 26 additions & 0 deletions core/http/endpoints/openai/realtime_model.go
@@ -13,6 +13,11 @@ import (
 	"google.golang.org/grpc"
 )

+var (
+	_ Model = new(wrappedModel)
+	_ Model = new(anyToAnyModel)
+)
+
 // wrappedModel represents a model which does not support Any-to-Any operations.
 // This means that we fake an Any-to-Any model by overriding some of the gRPC client methods
 // which are meant for Any-to-Any models; instead, we call a pipeline (e.g. STT->LLM->TTS)
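(Editor's note: the added var block is the standard Go compile-time assertion idiom. `_ Model = new(wrappedModel)` makes the build fail immediately if either concrete type stops satisfying the Model interface, which is useful while the interface is still growing stub by stub.)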
@@ -47,6 +52,27 @@ func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...g
 	return m.VADClient.VAD(ctx, in)
 }

+func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+	// TODO: Convert with pipeline (audio to text, text to LLM, result to TTS, and return it)
+	// sound.BufferAsWAV(audioData, "audio.wav")
+
+	return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+	// TODO: Convert with pipeline (audio to text, text to LLM, result to TTS, and return it)
+
+	return m.LLMClient.PredictStream(ctx, in, f)
+}
+
+func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) {
+	return m.LLMClient.Predict(ctx, in)
+}
+
+func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
+	return m.LLMClient.PredictStream(ctx, in, f)
+}
+
 // returns and loads either a wrapped model or a model that supports audio-to-audio
 func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {

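How the two implementations are expected to be picked (editor's sketch: the wiring below mirrors the diff's names, but attachModel and its call shape are illustrative, not part of the commit).

// attachModel resolves a Model for a session. newModel returns either a
// wrappedModel (faked Any-to-Any via an STT->LLM->TTS pipeline) or an
// anyToAnyModel (native audio in/out), but callers see a single interface.
func attachModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, session *Session, name string) error {
	m, err := newModel(cl, ml, appConfig, name)
	if err != nil {
		return err
	}
	session.ModelInterface = m // Predict/PredictStream now work uniformly
	return nil
}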
2 changes: 1 addition & 1 deletion pkg/grpc/backend.go
@@ -35,9 +35,9 @@ type Backend interface {
 	IsBusy() bool
 	HealthCheck(ctx context.Context) (bool, error)
 	Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error)
-	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...grpc.CallOption) error
+	Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error)
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
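Worth a glance (editor's note): Backend.PredictStream hands the callback a full *pb.Reply rather than raw bytes, so a streaming consumer can pull text and the new audio field from the same message. A minimal sketch, assuming a connected Backend client; streamAll is a hypothetical helper.

// streamAll accumulates streamed text and audio separately; each callback
// invocation receives one complete Reply message from the backend.
func streamAll(ctx context.Context, b Backend, prompt string) (string, []byte, error) {
	var text strings.Builder
	var audio []byte
	err := b.PredictStream(ctx, &pb.PredictOptions{Prompt: prompt}, func(reply *pb.Reply) {
		text.Write(reply.GetMessage())             // bytes message = 1
		audio = append(audio, reply.GetAudio()...) // new bytes audio = 5
	})
	return text.String(), audio, err
}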
