Skip to content

Commit

Permalink
CAII endpoint discovery (#60)
Browse files Browse the repository at this point in the history
* "wip on endpoint listing"

* "wip on list_endpoints typing"

* "refactoring to endpoint object"

* "wip filtering"

* "endpoints queried!"

* "refactoring"

* "wip on cleaning up types"

* "type cleanup complete"

* "moving files"

* "use a dummy embedding model for deletes"

* fix some bits from merge, get evals working again with CAII, tests passing

* formatting

* clean up ruff stuff

* use the chat llm for evals

* fix mypy for reformatting

* "wip on java reconciler"

* "reconciler don't do no model; start python work"

* "python - updating for summarization model"

* "comment out batch embeddings to get it working again"

* add handling for no summarization in the files table

* finish up ui and python for summarization

* make sure to update the time-updated fields on data sources and chat sessions

* use no-op models when we don't need real ones for summary functionality

* Update release version to dev-testing

* use the summarization llm when summarizing summaries

---------

Co-authored-by: Elijah Williams <ewilliams@cloudera.com>
Co-authored-by: actions-user <actions@github.com>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent 2dac585 commit bdfabb7
Show file tree
Hide file tree
Showing 53 changed files with 995 additions and 338 deletions.
2 changes: 0 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ DB_URL=jdbc:h2:../databases/rag

# If using CAII, fill these in:
CAII_DOMAIN=
CAII_INFERENCE_ENDPOINT_NAME=
CAII_EMBEDDING_ENDPOINT_NAME=

# set this to true if you have uv installed on your system, other wise don't include this
USE_SYSTEM_UV=true
Expand Down
8 changes: 0 additions & 8 deletions .project-metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ environment_variables:
default: ""
description: "The domain of the CAII service. Setting this will enable CAII as the sole source for both inference and embedding models."
required: false
CAII_INFERENCE_ENDPOINT_NAME:
default: ""
description: "The name of the inference endpoint for the CAII service. Required if CAII_DOMAIN is set."
required: false
CAII_EMBEDDING_ENDPOINT_NAME:
default: ""
description: "The name of the embedding endpoint for the CAII service. Required if CAII_DOMAIN is set."
required: false
DB_URL:
default: "jdbc:h2:file:~/databases/rag"
description: "Internal DB URL. Do not change."
Expand Down
1 change: 1 addition & 0 deletions backend/src/main/java/com/cloudera/cai/rag/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public record RagDataSource(
Long id,
String name,
String embeddingModel,
String summarizationModel,
Integer chunkSize,
Integer chunkOverlapPercent,
Instant timeCreated,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.cloudera.cai.rag.Types.RagDataSource;
import com.cloudera.cai.rag.configuration.JdbiConfiguration;
import com.cloudera.cai.util.exceptions.NotFound;
import java.time.Instant;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.jdbi.v3.core.Jdbi;
Expand All @@ -62,8 +63,8 @@ public Long createRagDataSource(RagDataSource input) {
handle -> {
var sql =
"""
INSERT INTO rag_data_source (name, chunk_size, chunk_overlap_percent, created_by_id, updated_by_id, connection_type, embedding_model)
VALUES (:name, :chunkSize, :chunkOverlapPercent, :createdById, :updatedById, :connectionType, :embeddingModel)
INSERT INTO rag_data_source (name, chunk_size, chunk_overlap_percent, created_by_id, updated_by_id, connection_type, embedding_model, summarization_model)
VALUES (:name, :chunkSize, :chunkOverlapPercent, :createdById, :updatedById, :connectionType, :embeddingModel, :summarizationModel)
""";
try (var update = handle.createUpdate(sql)) {
update.bindMethods(input);
Expand All @@ -78,7 +79,7 @@ public void updateRagDataSource(RagDataSource input) {
var sql =
"""
UPDATE rag_data_source
SET name = :name, connection_type = :connectionType, updated_by_id = :updatedById
SET name = :name, connection_type = :connectionType, updated_by_id = :updatedById, summarization_model = :summarizationModel, time_updated = :now
WHERE id = :id AND deleted IS NULL
""";
try (var update = handle.createUpdate(sql)) {
Expand All @@ -87,6 +88,8 @@ public void updateRagDataSource(RagDataSource input) {
.bind("updatedById", input.updatedById())
.bind("connectionType", input.connectionType())
.bind("id", input.id())
.bind("summarizationModel", input.summarizationModel())
.bind("now", Instant.now())
.execute();
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ public void resync() {
log.debug("checking for RAG documents to be summarized");
String sql =
"""
SELECT * from rag_data_source_document
WHERE summary_creation_timestamp IS NULL
AND time_created > :yesterday
SELECT rdsd.* from rag_data_source_document rdsd
JOIN rag_data_source rds ON rdsd.data_source_id = rds.id
WHERE rdsd.summary_creation_timestamp IS NULL
AND (rdsd.time_created > :yesterday OR rds.time_updated > :yesterday)
AND rds.summarization_model IS NOT NULL
""";
jdbi.useHandle(
handle -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,16 @@ public void delete(Long id) {
}

public void update(Types.Session input) {
var updatedInput = input.withTimeUpdated(Instant.now());
jdbi.useHandle(
handle -> {
var sql =
"""
UPDATE CHAT_SESSION
SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks
SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks, time_updated = :timeUpdated
WHERE id = :id
""";
handle.createUpdate(sql).bindMethods(input).execute();
handle.createUpdate(sql).bindMethods(updatedInput).execute();
});
}
}
4 changes: 3 additions & 1 deletion backend/src/main/java/com/cloudera/cai/util/IdGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

package com.cloudera.cai.util;

import java.util.Random;
import java.util.UUID;
import org.springframework.stereotype.Component;

Expand All @@ -53,6 +54,7 @@ public static IdGenerator createNull(String... dummyIds) {
}

private static class NullIdGenerator extends IdGenerator {
private final Random random = new Random();

private final String[] dummyIds;

Expand All @@ -62,7 +64,7 @@ private NullIdGenerator(String[] dummyIds) {

@Override
public String generateId() {
return dummyIds.length == 0 ? "StubbedId" : dummyIds[0];
return dummyIds.length == 0 ? "StubbedId-" + random.nextInt() : dummyIds[0];
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE rag_data_source DROP COLUMN summarization_model;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE rag_data_source ADD COLUMN summarization_model varchar(255);

COMMIT;
4 changes: 3 additions & 1 deletion backend/src/main/resources/migrations/migrations.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
13_add_chat_configuration.down.sql
13_add_chat_configuration.up.sql
14_add_embedding_model.down.sql
14_add_embedding_model.up.sql
14_add_embedding_model.up.sql
15_add_summarization_model.down.sql
15_add_summarization_model.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE rag_data_source DROP COLUMN embedding_model;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE rag_data_source ADD COLUMN embedding_model varchar(255);

COMMIT;
1 change: 1 addition & 0 deletions backend/src/test/java/com/cloudera/cai/rag/TestData.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public static Types.RagDataSource createTestDataSourceInstance(
null,
name,
"test_embedding_model",
"summarizationModel",
chunkSize,
chunkOverlapPercent,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ void updateName() {
newDataSource.id(),
"updated-name",
"test_embedding_model",
"summarizationModel",
newDataSource.chunkSize(),
newDataSource.chunkOverlapPercent(),
newDataSource.timeCreated(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
import static com.cloudera.cai.rag.Types.ConnectionType.MANUAL;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;

import com.cloudera.cai.rag.TestData;
import com.cloudera.cai.rag.configuration.JdbiConfiguration;
import com.cloudera.cai.util.exceptions.NotFound;
import java.time.Duration;
import java.time.Instant;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -75,21 +77,27 @@ void update() {
TestData.createTestDataSourceInstance("test-name", 512, 10, MANUAL)
.withCreatedById("abc")
.withUpdatedById("abc"));
assertThat(repository.getRagDataSourceById(id).name()).isEqualTo("test-name");
assertThat(repository.getRagDataSourceById(id).updatedById()).isEqualTo("abc");
var insertedDataSource = repository.getRagDataSourceById(id);
assertThat(insertedDataSource.name()).isEqualTo("test-name");
assertThat(insertedDataSource.updatedById()).isEqualTo("abc");
var timeInserted = insertedDataSource.timeUpdated();
assertThat(timeInserted).isNotNull();

var expectedRagDataSource =
TestData.createTestDataSourceInstance("new-name", 512, 10, API)
.withCreatedById("abc")
.withUpdatedById("def")
.withId(id)
.withDocumentCount(0);

// wait a moment so the updated time will always be later than insert time
await().atLeast(Duration.ofMillis(1));
repository.updateRagDataSource(expectedRagDataSource);
assertThat(repository.getRagDataSourceById(id))
var updatedDataSource = repository.getRagDataSourceById(id);
assertThat(updatedDataSource)
.usingRecursiveComparison()
.ignoringFieldsOfTypes(Instant.class)
.isEqualTo(expectedRagDataSource);
assertThat(updatedDataSource.timeUpdated()).isAfter(timeInserted);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ void getRagDocuments() {
null,
"test_datasource",
"test_embedding_model",
"summarizationModel",
1024,
20,
null,
Expand Down
Loading

0 comments on commit bdfabb7

Please sign in to comment.