Skip to content

Commit

Permalink
Wrap citations inside CitationMetadata (#6276)
Browse files Browse the repository at this point in the history
We were previously unwrapping Citations within CitationMetadata, but
we've decided to better align with the proto.


https://github.com/googleapis/googleapis/blob/7f9941f4ba22d6eb3bb7fa31f80aae3a1b3b957e/google/cloud/aiplatform/v1/content.proto#L346

b/368310789
  • Loading branch information
rlazo authored Sep 19, 2024
1 parent 2326592 commit d0fa299
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Base64
import com.google.firebase.vertexai.common.client.Schema
import com.google.firebase.vertexai.common.server.CitationSources
import com.google.firebase.vertexai.common.shared.Blob
import com.google.firebase.vertexai.common.shared.FileData
import com.google.firebase.vertexai.common.shared.FunctionCall
Expand All @@ -32,6 +31,7 @@ import com.google.firebase.vertexai.type.BlobPart
import com.google.firebase.vertexai.type.BlockReason
import com.google.firebase.vertexai.type.BlockThreshold
import com.google.firebase.vertexai.type.Candidate
import com.google.firebase.vertexai.type.Citation
import com.google.firebase.vertexai.type.CitationMetadata
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
Expand Down Expand Up @@ -181,7 +181,7 @@ internal fun JSONObject.toInternal() = Json.decodeFromString<JsonObject>(toStrin

internal fun com.google.firebase.vertexai.common.server.Candidate.toPublic(): Candidate {
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.toPublic()
val finishReason = finishReason.toPublic()

return Candidate(
Expand Down Expand Up @@ -228,8 +228,11 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
}
}

internal fun CitationSources.toPublic() =
CitationMetadata(startIndex = startIndex, endIndex = endIndex, uri = uri ?: "", license = license)
internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic() =
Citation(startIndex = startIndex, endIndex = endIndex, uri = uri, license = license)

internal fun com.google.firebase.vertexai.common.server.CitationMetadata.toPublic() =
CitationMetadata(citationSources.map { it.toPublic() })

internal fun com.google.firebase.vertexai.common.server.SafetyRating.toPublic() =
SafetyRating(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Candidate
internal constructor(
val content: Content,
val safetyRatings: List<SafetyRating>,
val citationMetadata: List<CitationMetadata>,
val citationMetadata: CitationMetadata?,
val finishReason: FinishReason?
)

Expand All @@ -37,15 +37,25 @@ internal constructor(
)

/**
* Provides citation metadata for sourcing of content provided by the model between a given
* A collection of source attributions for a piece of content.
*
* @property citations A list of individual cited sources and the parts of the content to which they
* apply.
*/
class CitationMetadata internal constructor(val citations: List<Citation>)

/**
* Provides citation information for sourcing of content provided by the model between a given
* [startIndex] and [endIndex].
*
* @property startIndex The beginning of the citation.
* @property endIndex The end of the citation.
* @property uri The URI of the cited work.
* @property license The license under which the cited work is distributed.
* @property startIndex The inclusive beginning of a sequence in a model response that derives from
* a cited source.
* @property endIndex The exclusive end of a sequence in a model response that derives from a cited
* source.
* @property uri A link to the cited source, if available.
* @property license The license the cited source work is distributed under, if specified.
*/
class CitationMetadata
class Citation
internal constructor(
val startIndex: Int = 0,
val endIndex: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ internal class StreamingSnapshotTests {

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true
responseList.any {
it.candidates.any { it.citationMetadata?.citations?.isNotEmpty() ?: false }
} shouldBe true
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().citationMetadata.size shouldBe 3
response.candidates.first().citationMetadata?.citations?.size shouldBe 3
}
}

Expand All @@ -240,11 +240,14 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")

response.candidates.isEmpty() shouldBe false
response.candidates.first().citationMetadata.isEmpty() shouldBe false
response.candidates.first().citationMetadata?.citations?.isEmpty() shouldBe false
// Verify the values in the citation source
with(response.candidates.first().citationMetadata.first()) {
license shouldBe null
startIndex shouldBe 0
val firstCitation = response.candidates.first().citationMetadata?.citations?.first()
if (firstCitation != null) {
with(firstCitation) {
license shouldBe null
startIndex shouldBe 0
}
}
}
}
Expand Down

0 comments on commit d0fa299

Please sign in to comment.