Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split Go Blob parts into Media and Data. #246

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions go/ai/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,26 @@ type partKind int8

const (
partText partKind = iota
partBlob
partMedia
partData
partToolRequest
partToolResponse
)

// NewTextPart returns a Part containing raw string data.
// NewTextPart returns a Part containing text.
func NewTextPart(text string) *Part {
return &Part{kind: partText, text: text}
}

// NewBlobPart returns a Part containing structured data described
// NewMediaPart returns a Part containing structured data described
// by the given mimeType.
func NewBlobPart(mimeType, contents string) *Part {
return &Part{kind: partBlob, contentType: mimeType, text: contents}
func NewMediaPart(mimeType, contents string) *Part {
return &Part{kind: partMedia, contentType: mimeType, text: contents}
}

// NewDataPart returns a Part containing raw string data.
func NewDataPart(contents string) *Part {
return &Part{kind: partData, text: contents}
}

// NewToolRequestPart returns a Part containing a request from
Expand All @@ -76,9 +82,14 @@ func (p *Part) IsText() bool {
return p.kind == partText
}

// IsBlob reports whether the [Part] contains blob (non-plain-text) data.
func (p *Part) IsBlob() bool {
return p.kind == partBlob
// IsMedia reports whether the [Part] contains structured media data.
func (p *Part) IsMedia() bool {
return p.kind == partMedia
}

// IsData reports whether the [Part] contains unstructured data.
func (p *Part) IsData() bool {
return p.kind == partData
}

// IsToolRequest reports whether the [Part] contains a request to run a tool.
Expand Down Expand Up @@ -128,14 +139,19 @@ func (p *Part) MarshalJSON() ([]byte, error) {
Text: p.text,
}
return json.Marshal(v)
case partBlob:
case partMedia:
v := mediaPart{
Media: &mediaPartMedia{
ContentType: p.contentType,
Url: p.text,
},
}
return json.Marshal(v)
case partData:
v := dataPart{
Data: p.text,
}
return json.Marshal(v)
case partToolRequest:
// TODO: make sure these types marshal/unmarshal nicely
// between Go and javascript. At the very least the
Expand Down Expand Up @@ -166,6 +182,7 @@ func (p *Part) UnmarshalJSON(b []byte) error {
var s struct {
Text string `json:"text,omitempty"`
Media *mediaPartMedia `json:"media,omitempty"`
Data string `json:"data,omitempty"`
ToolReq *ToolRequest `json:"toolreq,omitempty"`
ToolResp *ToolResponse `json:"toolresp,omitempty"`
}
Expand All @@ -176,7 +193,7 @@ func (p *Part) UnmarshalJSON(b []byte) error {

switch {
case s.Media != nil:
p.kind = partBlob
p.kind = partMedia
p.text = s.Media.Url
p.contentType = s.Media.ContentType
case s.ToolReq != nil:
Expand All @@ -189,6 +206,11 @@ func (p *Part) UnmarshalJSON(b []byte) error {
p.kind = partText
p.text = s.Text
p.contentType = ""
if s.Data != "" {
// Note: if part is completely empty, we use text by default.
p.kind = partData
p.text = s.Data
}
}
return nil
}
Expand Down
10 changes: 8 additions & 2 deletions go/ai/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,14 @@ func TestDocumentJSON(t *testing.T) {
text: "hi",
},
&Part{
kind: partBlob,
kind: partMedia,
contentType: "text/plain",
text: "data:,bye",
},
&Part{
kind: partData,
text: "somedata\x00string",
},
&Part{
kind: partToolRequest,
toolRequest: &ToolRequest{
Expand Down Expand Up @@ -85,8 +89,10 @@ func TestDocumentJSON(t *testing.T) {
switch a.kind {
case partText:
return a.text == b.text
case partBlob:
case partMedia:
return a.contentType == b.contentType && a.text == b.text
case partData:
return a.text == b.text
case partToolRequest:
return reflect.DeepEqual(a.toolRequest, b.toolRequest)
case partToolResponse:
Expand Down
2 changes: 1 addition & 1 deletion go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ const (
FinishReasonUnknown FinishReason = "unknown"
)

type DataPart struct {
type dataPart struct {
Data any `json:"data,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
Expand Down
3 changes: 2 additions & 1 deletion go/core/schemas.config
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ Candidate pkg ai
CandidateError pkg ai
CandidateErrorCode pkg ai
CandidateFinishReason pkg ai
DataPart pkg ai
DocumentData pkg ai
GenerateResponse pkg ai
GenerateResponseChunk pkg ai
Expand Down Expand Up @@ -159,6 +158,8 @@ MediaPart.toolRequest omit
MediaPart.toolResponse omit
MediaPartMedia pkg ai
MediaPartMedia name mediaPartMedia
DataPart pkg ai
DataPart name dataPart
ModelInfo pkg ai
ModelInfoSupports pkg ai
ModelInfoSupports.output type OutputFormat
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (p *Prompt) toParts(str string) []*ai.Part {

media := str[m[0]+len(mediaPrefix) : m[1]-len(mediaSuffix)]
url, contentType, _ := strings.Cut(media, " ")
ret = append(ret, ai.NewBlobPart(contentType, url))
ret = append(ret, ai.NewMediaPart(contentType, url))

i = m[1]
}
Expand Down
10 changes: 5 additions & 5 deletions go/plugins/dotprompt/render_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestRenderMessages(t *testing.T) {
{
Role: ai.RoleUser,
Content: []*ai.Part{
ai.NewBlobPart("", "https://some.image.url/image.jpg"),
ai.NewMediaPart("", "https://some.image.url/image.jpg"),
ai.NewTextPart(" Describe the image above."),
},
},
Expand All @@ -137,11 +137,11 @@ func TestRenderMessages(t *testing.T) {
Role: ai.RoleUser,
Content: []*ai.Part{
ai.NewTextPart("Look at these images: "),
ai.NewBlobPart("", "http://1.png"),
ai.NewBlobPart("", "https://2.png"),
ai.NewBlobPart("", ""),
ai.NewMediaPart("", "http://1.png"),
ai.NewMediaPart("", "https://2.png"),
ai.NewMediaPart("", ""),
ai.NewTextPart(" Do you like them? Here is another: "),
ai.NewBlobPart("", "http://anotherImage.png"),
ai.NewMediaPart("", "http://anotherImage.png"),
},
},
},
Expand Down
6 changes: 4 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate {
case genai.Text:
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewBlobPart(part.MIMEType, string(part.Data))
p = ai.NewMediaPart(part.MIMEType, string(part.Data))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Expand Down Expand Up @@ -236,8 +236,10 @@ func convertPart(p *ai.Part) genai.Part {
switch {
case p.IsText():
return genai.Text(p.Text())
case p.IsBlob():
case p.IsMedia():
return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())}
case p.IsData():
panic("googleai does not support Data parts")
case p.IsToolResponse():
toolResp := p.ToolResponse()
return genai.FunctionResponse{
Expand Down
6 changes: 4 additions & 2 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate {
case genai.Text:
p = ai.NewTextPart(string(part))
case genai.Blob:
p = ai.NewBlobPart(part.MIMEType, string(part.Data))
p = ai.NewMediaPart(part.MIMEType, string(part.Data))
case genai.FunctionCall:
p = ai.NewToolRequestPart(&ai.ToolRequest{
Name: part.Name,
Expand Down Expand Up @@ -204,8 +204,10 @@ func convertPart(p *ai.Part) genai.Part {
switch {
case p.IsText():
return genai.Text(p.Text())
case p.IsBlob():
case p.IsMedia():
return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())}
case p.IsData():
panic("vertexai does not support Data parts")
case p.IsToolResponse():
toolResp := p.ToolResponse()
return genai.FunctionResponse{
Expand Down
Loading