
Commit

Fix #81, implement #85 (#98)
* Fixes #81

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

* Implements #85

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

* Add logging support and fix trunk errors

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

* Fixes #99

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

* Added documentation

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>

---------

Signed-off-by: JonahSussman <sussmanjonah@gmail.com>
JonahSussman authored Mar 22, 2024
1 parent faf38f6 commit d709950
Showing 17 changed files with 244 additions and 155 deletions.
41 changes: 41 additions & 0 deletions README.md
@@ -100,6 +100,47 @@ Note: For purposes of this initial prototype we are using an example of Java EE
- To access via the API, see the 'Documentation' section after logging in to https://bam.res.ibm.com/
- The 'Documentation' section includes an embedded field where you can generate or obtain an API key.

##### Selecting Other Models

We also support other models. To change which LLM you are targeting, open `config.toml` and change the `[models]` section to one of the following:

**IBM served Granite**

```toml
provider = "IBMGranite"
args = { model_id = "ibm/granite-13b-chat-v2" }
```

**IBM served Mistral**

```toml
provider = "IBMOpenSource"
args = { model_id = "ibm-mistralai/mixtral-8x7b-instruct-v01-q" }
```

**IBM served Llama 2**

```toml
provider = "IBMOpenSource"
args = { model_id = "meta-llama/llama-2-13b-chat" }
```

**OpenAI GPT 3.5**

```toml
provider = "OpenAI"
args = { model_id = "gpt-4" }
```

**OpenAI GPT 4**

```toml
provider = "OpenAI"
args = { model_id = "gpt-3.5-turbo" }
```

provider = "IBMGranite"

### Demo Steps

#### Backend
61 changes: 0 additions & 61 deletions kai-service/mock-client.py
@@ -276,66 +276,6 @@ async def main():
{
"violation_name": "jms-to-reactive-quarkus-00010",
"ruleset_name": "kai/quarkus",
"incident_snip": """ 1 /*
2 * JBoss, Home of Professional Open Source
3 * Copyright 2015, Red Hat, Inc. and/or its affiliates, and individual
4 * contributors by the @authors tag. See the copyright.txt in the
5 * distribution for a full listing of individual contributors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the \"License\");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an \"AS IS\" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.jboss.as.quickstarts.cmt.mdb;
18
19 import java.util.logging.Logger;
20
21 import javax.ejb.ActivationConfigProperty;
22 import javax.ejb.MessageDriven;
23 import javax.jms.JMSException;
24 import javax.jms.Message;
25 import javax.jms.MessageListener;
26 import javax.jms.TextMessage;
27
28 /**
29 * <p>
30 * A simple Message Driven Bean that asynchronously receives and processes the messages that are sent to the queue.
31 * </p>
32 *
33 * @author Serge Pagop (spagop@redhat.com)
34 *
35 */
36 @MessageDriven(name = \"HelloWorldMDB\", activationConfig = {
37 @ActivationConfigProperty(propertyName = \"destinationType\", propertyValue = \"javax.jms.Queue\"),
38 @ActivationConfigProperty(propertyName = \"destination\", propertyValue = \"queue/CMTQueue\"),
39 @ActivationConfigProperty(propertyName = \"acknowledgeMode\", propertyValue = \"Auto-acknowledge\") })
40 public class HelloWorldMDB implements MessageListener {
41
42 private static final Logger logManager = Logger.getLogger(HelloWorldMDB.class.toString());
43
44 /**
45 * @see MessageListener#onMessage(Message)
46 */
47 public void onMessage(Message receivedMessage) {
48 TextMessage textMsg = null;
49 try {
50 if (receivedMessage instanceof TextMessage) {
51 textMsg = (TextMessage) receivedMessage;
52 logManager.info(\"Received Message: \" + textMsg.getText());
53 } else {
54 logManager.warning(\"Message of wrong type: \" + receivedMessage.getClass().getName());
55 }
56 } catch (JMSException ex) {
57 throw new RuntimeException(ex);
58 }
59 }
60 }""",
"incident_variables": {
"file": "file:///tmp/source-code/src/main/java/org/jboss/as/quickstarts/cmt/mdb/HelloWorldMDB.java",
"kind": "Class",
@@ -349,7 +289,6 @@
{
"violation_name": "change_variables",
"ruleset_name": "kai/funny",
"incident_snip": "",
"incident_variables": {
"file": "file:///tmp/source-code/src/main/java/org/jboss/as/quickstarts/cmt/mdb/HelloWorldMDB.java",
"kind": "Class",
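With `incident_snip` dropped from the mock payloads, an incident entry in this client reduces to the remaining fields. A trimmed sketch based on the hunks above (any further `incident_variables` are elided in the diff):

```python
# Trimmed sketch of a mock incident after this change; the fields come
# straight from the diff above, with remaining incident_variables elided.
incident = {
    "violation_name": "jms-to-reactive-quarkus-00010",
    "ruleset_name": "kai/quarkus",
    "incident_variables": {
        "file": "file:///tmp/source-code/src/main/java/org/jboss/as/quickstarts/cmt/mdb/HelloWorldMDB.java",
        "kind": "Class",
    },
}
```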
30 changes: 30 additions & 0 deletions kai/config.toml
@@ -0,0 +1,30 @@
[postgresql]
host = "127.0.0.1"
database = "kai"
user = "kai"
password = "dog8code"

[models]
# Choose one of the following providers (see model_provider.py for more info):
# IBM served Granite
# provider = "IBMGranite"
# args = { model_id = "ibm/granite-13b-chat-v2" }
# IBM served Mistral
# provider = "IBMOpenSource"
# args = { model_id = "ibm-mistralai/mixtral-8x7b-instruct-v01-q" }
# IBM served Llama 2
# provider = "IBMOpenSource"
# args = { model_id = "meta-llama/llama-2-13b-chat" }
# OpenAI GPT 3.5
# provider = "OpenAI"
# args = { model_id = "gpt-3.5-turbo" }
# OpenAI GPT 4
# provider = "OpenAI"
# args = { model_id = "gpt-4" }
provider = "IBMGranite"
args = { model_id = "ibm/granite-13b-chat-v2" }

# Here for later, we want to be able to configure which embeddings are used when
# we start to integrate them into the project
[embeddings]
todo = true
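A quick way to inspect the active model settings from Python — a minimal sketch, assuming Python 3.11+ for the standard-library TOML parser:

```python
import tomllib  # standard library in Python 3.11+

# Load the server config and report which model provider is active.
with open("kai/config.toml", "rb") as f:
    config = tomllib.load(f)

print(config["models"]["provider"])          # IBMGranite
print(config["models"]["args"]["model_id"])  # ibm/granite-13b-chat-v2
```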
36 changes: 36 additions & 0 deletions kai/data/jsonschema/server_config.json
@@ -0,0 +1,36 @@
{
  "type": "object",
  "properties": {
    "postgresql": {
      "type": "object",
      "properties": {
        "host": {
          "type": "string"
        },
        "database": {
          "type": "string"
        },
        "user": {
          "type": "string"
        },
        "password": {
          "type": "string"
        }
      },
      "required": ["host", "database", "user", "password"]
    },
    "models": {
      "type": "object",
      "properties": {
        "provider": {
          "type": "string"
        },
        "args": {
          "type": "object"
        }
      },
      "required": ["provider", "args"]
    }
  },
  "required": ["postgresql", "models"]
}
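The schema mirrors the `[postgresql]` and `[models]` tables in `config.toml`. A sketch of validating the parsed config against it, assuming the third-party `jsonschema` package (the repo may wire this up differently):

```python
import json
import tomllib  # standard library in Python 3.11+

import jsonschema  # third-party: pip install jsonschema

# Parse the TOML config, then check it against the JSON schema.
with open("kai/config.toml", "rb") as f:
    config = tomllib.load(f)

with open("kai/data/jsonschema/server_config.json") as f:
    schema = json.load(f)

# Raises jsonschema.exceptions.ValidationError if a required key is missing.
jsonschema.validate(instance=config, schema=schema)
```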
5 changes: 1 addition & 4 deletions kai/data/misc/ai-test.py
@@ -1,7 +1,4 @@
-from langchain import PromptTemplate
-from langchain.callbacks import FileCallbackHandler
-from langchain.chains import LLMChain
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI

template = """
You are an excellent enterprise architect who has an extensive
5 changes: 0 additions & 5 deletions kai/database.ini

This file was deleted.
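The deleted `database.ini` presumably held the PostgreSQL settings that now live in the `[postgresql]` table of `config.toml`. A sketch of consuming that table, assuming the `psycopg2` driver (the actual connection code is not shown in this commit):

```python
import tomllib  # standard library in Python 3.11+

import psycopg2  # third-party PostgreSQL driver

with open("kai/config.toml", "rb") as f:
    config = tomllib.load(f)

# host, database, user, and password map directly onto
# psycopg2.connect() keyword arguments.
conn = psycopg2.connect(**config["postgresql"])
```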

9 changes: 6 additions & 3 deletions kai/embedding_provider.py
@@ -1,3 +1,4 @@
+import logging
import random
from abc import ABC, abstractmethod
from enum import Enum
@@ -106,11 +107,11 @@ def get_embedding(self, inp: str) -> list | None:
                timeout=50,
            )
        except requests.Timeout:
-            print("Error: timout after 50 seconds")
+            logging.error("Error: timeout after 50 seconds")
            return None

        if response.status_code != 200:
-            print("Error:", response.status_code, response.text)
+            logging.error("Error: %s %s", response.status_code, response.text)
            return None

        return response.json()["data"][0]["embedding"]
Expand Down Expand Up @@ -140,7 +141,9 @@ def get_embedding(self, inp: str) -> list | None:
        toks = self.instructor.tokenize([inp])["input_ids"].tolist()[0]

        if len(toks) > self.max_tokens:
-            print(f"Length of tokens is {len(toks)}. Truncating to {self.max_tokens}.")
+            logging.info(
+                f"Length of tokens is {len(toks)}. Truncating to {self.max_tokens}."
+            )

        result: numpy.ndarray = self.instructor.encode(prompt).tolist()[0]
        return result
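The commit message mentions adding logging support; for the module-level `logging.error`/`logging.info` calls above to be emitted, something must configure the root logger. A minimal assumed setup (the exact configuration added by this commit is not shown in these hunks):

```python
import logging

# Assumed setup; the real configuration added by this commit may differ.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
```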