diff --git a/go/ai/document.go b/go/ai/document.go index 33fec9811..eb40f3717 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -42,20 +42,26 @@ type partKind int8 const ( partText partKind = iota - partBlob + partMedia + partData partToolRequest partToolResponse ) -// NewTextPart returns a Part containing raw string data. +// NewTextPart returns a Part containing text. func NewTextPart(text string) *Part { return &Part{kind: partText, text: text} } -// NewBlobPart returns a Part containing structured data described +// NewMediaPart returns a Part containing structured data described // by the given mimeType. -func NewBlobPart(mimeType, contents string) *Part { - return &Part{kind: partBlob, contentType: mimeType, text: contents} +func NewMediaPart(mimeType, contents string) *Part { + return &Part{kind: partMedia, contentType: mimeType, text: contents} +} + +// NewDataPart returns a Part containing raw string data. +func NewDataPart(contents string) *Part { + return &Part{kind: partData, text: contents} } // NewToolRequestPart returns a Part containing a request from @@ -76,9 +82,14 @@ func (p *Part) IsText() bool { return p.kind == partText } -// IsBlob reports whether the [Part] contains blob (non-plain-text) data. -func (p *Part) IsBlob() bool { - return p.kind == partBlob +// IsMedia reports whether the [Part] contains structured media data. +func (p *Part) IsMedia() bool { + return p.kind == partMedia +} + +// IsData reports whether the [Part] contains unstructured data. +func (p *Part) IsData() bool { + return p.kind == partData } // IsToolRequest reports whether the [Part] contains a request to run a tool. @@ -128,7 +139,7 @@ func (p *Part) MarshalJSON() ([]byte, error) { Text: p.text, } return json.Marshal(v) - case partBlob: + case partMedia: v := mediaPart{ Media: &mediaPartMedia{ ContentType: p.contentType, @@ -136,6 +147,11 @@ func (p *Part) MarshalJSON() ([]byte, error) { }, } return json.Marshal(v) + case partData: + v := dataPart{ + Data: p.text, + } + return json.Marshal(v) case partToolRequest: // TODO: make sure these types marshal/unmarshal nicely // between Go and javascript. At the very least the @@ -166,6 +182,7 @@ func (p *Part) UnmarshalJSON(b []byte) error { var s struct { Text string `json:"text,omitempty"` Media *mediaPartMedia `json:"media,omitempty"` + Data string `json:"data,omitempty"` ToolReq *ToolRequest `json:"toolreq,omitempty"` ToolResp *ToolResponse `json:"toolresp,omitempty"` } @@ -176,7 +193,7 @@ func (p *Part) UnmarshalJSON(b []byte) error { switch { case s.Media != nil: - p.kind = partBlob + p.kind = partMedia p.text = s.Media.Url p.contentType = s.Media.ContentType case s.ToolReq != nil: @@ -189,6 +206,11 @@ func (p *Part) UnmarshalJSON(b []byte) error { p.kind = partText p.text = s.Text p.contentType = "" + if s.Data != "" { + // Note: if part is completely empty, we use text by default. + p.kind = partData + p.text = s.Data + } } return nil } diff --git a/go/ai/document_test.go b/go/ai/document_test.go index 0ca412b3a..bc30eb7f4 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -46,10 +46,14 @@ func TestDocumentJSON(t *testing.T) { text: "hi", }, &Part{ - kind: partBlob, + kind: partMedia, contentType: "text/plain", text: "data:,bye", }, + &Part{ + kind: partData, + text: "somedata\x00string", + }, &Part{ kind: partToolRequest, toolRequest: &ToolRequest{ @@ -85,8 +89,10 @@ func TestDocumentJSON(t *testing.T) { switch a.kind { case partText: return a.text == b.text - case partBlob: + case partMedia: return a.contentType == b.contentType && a.text == b.text + case partData: + return a.text == b.text case partToolRequest: return reflect.DeepEqual(a.toolRequest, b.toolRequest) case partToolResponse: diff --git a/go/ai/gen.go b/go/ai/gen.go index 531c6fcbc..bf440600f 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -53,7 +53,7 @@ const ( FinishReasonUnknown FinishReason = "unknown" ) -type DataPart struct { +type dataPart struct { Data any `json:"data,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/go/core/schemas.config b/go/core/schemas.config index c7c402875..76aa30e37 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -127,7 +127,6 @@ Candidate pkg ai CandidateError pkg ai CandidateErrorCode pkg ai CandidateFinishReason pkg ai -DataPart pkg ai DocumentData pkg ai GenerateResponse pkg ai GenerateResponseChunk pkg ai @@ -159,6 +158,8 @@ MediaPart.toolRequest omit MediaPart.toolResponse omit MediaPartMedia pkg ai MediaPartMedia name mediaPartMedia +DataPart pkg ai +DataPart name dataPart ModelInfo pkg ai ModelInfoSupports pkg ai ModelInfoSupports.output type OutputFormat diff --git a/go/plugins/dotprompt/render.go b/go/plugins/dotprompt/render.go index 20fd9edeb..82bda0a89 100644 --- a/go/plugins/dotprompt/render.go +++ b/go/plugins/dotprompt/render.go @@ -184,7 +184,7 @@ func (p *Prompt) toParts(str string) []*ai.Part { media := str[m[0]+len(mediaPrefix) : m[1]-len(mediaSuffix)] url, contentType, _ := strings.Cut(media, " ") - ret = append(ret, ai.NewBlobPart(contentType, url)) + ret = append(ret, ai.NewMediaPart(contentType, url)) i = m[1] } diff --git a/go/plugins/dotprompt/render_test.go b/go/plugins/dotprompt/render_test.go index 27b98d8d9..8ffe1022e 100644 --- a/go/plugins/dotprompt/render_test.go +++ b/go/plugins/dotprompt/render_test.go @@ -115,7 +115,7 @@ func TestRenderMessages(t *testing.T) { { Role: ai.RoleUser, Content: []*ai.Part{ - ai.NewBlobPart("", "https://some.image.url/image.jpg"), + ai.NewMediaPart("", "https://some.image.url/image.jpg"), ai.NewTextPart(" Describe the image above."), }, }, @@ -137,11 +137,11 @@ func TestRenderMessages(t *testing.T) { Role: ai.RoleUser, Content: []*ai.Part{ ai.NewTextPart("Look at these images: "), - ai.NewBlobPart("", "http://1.png"), - ai.NewBlobPart("", "https://2.png"), - ai.NewBlobPart("", "data:image/jpeg;base64,abc123"), + ai.NewMediaPart("", "http://1.png"), + ai.NewMediaPart("", "https://2.png"), + ai.NewMediaPart("", "data:image/jpeg;base64,abc123"), ai.NewTextPart(" Do you like them? Here is another: "), - ai.NewBlobPart("", "http://anotherImage.png"), + ai.NewMediaPart("", "http://anotherImage.png"), }, }, }, diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index f2dedc32f..17eea79c1 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -164,7 +164,7 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { case genai.Text: p = ai.NewTextPart(string(part)) case genai.Blob: - p = ai.NewBlobPart(part.MIMEType, string(part.Data)) + p = ai.NewMediaPart(part.MIMEType, string(part.Data)) case genai.FunctionCall: p = ai.NewToolRequestPart(&ai.ToolRequest{ Name: part.Name, @@ -236,8 +236,10 @@ func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): return genai.Text(p.Text()) - case p.IsBlob(): + case p.IsMedia(): return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + case p.IsData(): + panic("googleai does not support Data parts") case p.IsToolResponse(): toolResp := p.ToolResponse() return genai.FunctionResponse{ diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index ff00c0555..f9e0c1c67 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -137,7 +137,7 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { case genai.Text: p = ai.NewTextPart(string(part)) case genai.Blob: - p = ai.NewBlobPart(part.MIMEType, string(part.Data)) + p = ai.NewMediaPart(part.MIMEType, string(part.Data)) case genai.FunctionCall: p = ai.NewToolRequestPart(&ai.ToolRequest{ Name: part.Name, @@ -204,8 +204,10 @@ func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): return genai.Text(p.Text()) - case p.IsBlob(): + case p.IsMedia(): return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + case p.IsData(): + panic("vertexai does not support Data parts") case p.IsToolResponse(): toolResp := p.ToolResponse() return genai.FunctionResponse{