Skip to content

Commit

Permalink
Merge pull request #65 from mjuCapstone/feat/#64
Browse files Browse the repository at this point in the history
Feat/#64 : GPT 요청 방식을 수정합니다.
  • Loading branch information
hyunw9 authored Aug 19, 2024
2 parents 0dfab86 + fb6ca65 commit 55ca1e4
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 38 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/deployment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ jobs:
echo "${{ secrets.PROPERTIES }}" > ./application.yml
shell: bash

- name: make model foundation
if: contains(github.ref, 'develop')
run: |
cd ./src/main/resources
touch ./instruction.txt
echo "${{ secrets.ENV_INSTRUCTION }}" > ./instruction.txt
touch ./menu_final.txt
echo "${{ secrets.ENV_MENU_FINAL }}" > ./menu_final.txt
shell: bash

# gradle build 위한 permission
- name: Grant execute permission for gradlew
run: chmod +x ./gradlew
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ foodData.csv
Untitled-1.py
/src/main/resources/secret.yml
/menu-data
*.txt
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ dependencies {

//WebFlux
// implementation 'org.springframework:spring-webflux:6.1.6'

//azure
implementation group: 'com.azure', name: 'azure-ai-openai-assistants', version: '1.0.0-beta.3'

}

tasks.named('test') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ public List<TotalRecommendResponse> getChatResponse(MenuRecommendRequest menuRec
log.info(prompt);

List<Menu> response = gptManager.sendOpenAIRequest(prompt);

// List<Menu> recommends = response.menus();
log.info("response size : " + response.size());

return response.stream().map(this::getTotalRecommendResponseByFoodName)
.collect(Collectors.toUnmodifiableList());
Expand All @@ -55,8 +54,7 @@ private String createNutritionPrompt(MenuRecommendRequest request,
prefPrompt = "사용자는 " + tasteType + " " + menuCountry + " " + ingredient + " 음식을 선호해.";
}
String result = String.format(
"사용자가 %s으로 %s 해 먹을 식단을 업로드한 파일 내에서 추천해줘. 탄수화물 %dg, 단백질 %dg, 지방 %dg 을 섭취해야 해." + prefPrompt
+ " 응답 형식은 다른 말 없이 무조건 다음과 같아야 해 : JSON [String , int]",
"사용자가 %s으로 %s로 먹어야해. 탄수화물 %dg, 단백질 %dg, 지방 %dg 을 섭취해야 해." + prefPrompt,
request.mealTime(),request.cookOrDelivery(),supposedNutrition.carbohydrate(),
supposedNutrition.protein(),supposedNutrition.fat()
);
Expand All @@ -65,6 +63,7 @@ private String createNutritionPrompt(MenuRecommendRequest request,
}

public TotalRecommendResponse getTotalRecommendResponseByFoodName(Menu response) {
log.info(response.name());
Food food = staticFoodService.findFoodByName(response.name());
int amount = response.amount();
int kcal = (int) ((double) food.getKcal() * amount / 100);
Expand Down
91 changes: 91 additions & 0 deletions src/main/java/com/mju/capstone/recommend/config/AzureConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package com.mju.capstone.recommend.config;


import com.azure.ai.openai.assistants.AssistantsClient;
import com.azure.ai.openai.assistants.AssistantsClientBuilder;
import com.azure.ai.openai.assistants.AssistantsServiceVersion;
import com.azure.ai.openai.assistants.models.Assistant;
import com.azure.ai.openai.assistants.models.AssistantCreationOptions;
import com.azure.ai.openai.assistants.models.CreateFileSearchToolResourceOptions;
import com.azure.ai.openai.assistants.models.CreateFileSearchToolResourceVectorStoreOptions;
import com.azure.ai.openai.assistants.models.CreateFileSearchToolResourceVectorStoreOptionsList;
import com.azure.ai.openai.assistants.models.CreateToolResourcesOptions;
import com.azure.ai.openai.assistants.models.FileDetails;
import com.azure.ai.openai.assistants.models.FilePurpose;
import com.azure.ai.openai.assistants.models.FileSearchToolDefinition;
import com.azure.ai.openai.assistants.models.OpenAIFile;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.BinaryData;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@RequiredArgsConstructor
@Slf4j
public class AzureConfig {

@Value("${azure.credential}")
private String credentialKey;
@Value("${azure.endpoint}")
private String endpoint;
@Value("${azure.key}")
private String key;
@Value("${azure.model}")
private String model;

@Bean
AssistantsClient assistantsClient() {
return new AssistantsClientBuilder()
.credential(new AzureKeyCredential(key))
.serviceVersion(AssistantsServiceVersion.getLatest())
.endpoint(endpoint)
.buildClient();
}

@Bean
public Assistant customAssistant(AssistantsClient client) throws IOException {

Path filePath = Paths.get("src/main/resources/menu_final.txt");
BinaryData fileData = BinaryData.fromFile(filePath);
FileDetails fileDetails = new FileDetails(fileData, "menu_final.txt");


OpenAIFile openAIFile = client.uploadFile(fileDetails, FilePurpose.ASSISTANTS);

String instructions = loadInstructionsFromFile("instruction2.txt");
log.info("Application Started with Instructions: {}", instructions);

CreateToolResourcesOptions createToolResourcesOptions = new CreateToolResourcesOptions();
createToolResourcesOptions.setFileSearch(
new CreateFileSearchToolResourceOptions(
new CreateFileSearchToolResourceVectorStoreOptionsList(
Arrays.asList(new CreateFileSearchToolResourceVectorStoreOptions(Arrays.asList(openAIFile.getId()))))));

return client.createAssistant(
new AssistantCreationOptions(model)
.setName("영양사")
.setInstructions(instructions)
.setTools(Arrays.asList(new FileSearchToolDefinition()))
.setToolResources(createToolResourcesOptions)
);
}

private String loadInstructionsFromFile(String filePath) throws IOException {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(
Objects.requireNonNull(this.getClass().getClassLoader().getResourceAsStream(filePath)), StandardCharsets.UTF_8))) {
return reader.lines().collect(Collectors.joining("\n"));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,61 +1,92 @@
package com.mju.capstone.recommend.repository;

import static org.springframework.http.MediaType.APPLICATION_JSON;
import static com.azure.ai.openai.assistants.models.MessageRole.USER;

import com.mju.capstone.global.exception.BusinessException;
import com.mju.capstone.global.response.message.ErrorMessage;
import com.azure.ai.openai.assistants.AssistantsClient;
import com.azure.ai.openai.assistants.models.Assistant;
import com.azure.ai.openai.assistants.models.AssistantThread;
import com.azure.ai.openai.assistants.models.AssistantThreadCreationOptions;
import com.azure.ai.openai.assistants.models.CreateRunOptions;
import com.azure.ai.openai.assistants.models.MessageContent;
import com.azure.ai.openai.assistants.models.MessageTextContent;
import com.azure.ai.openai.assistants.models.PageableList;
import com.azure.ai.openai.assistants.models.RunStatus;
import com.azure.ai.openai.assistants.models.ThreadMessage;
import com.azure.ai.openai.assistants.models.ThreadMessageOptions;
import com.azure.ai.openai.assistants.models.ThreadRun;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mju.capstone.recommend.domain.GptManager;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.mju.capstone.recommend.dto.response.Menu;
import java.util.ArrayList;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.PropertySource;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Repository;
import org.springframework.web.client.RestTemplate;

@Repository
@PropertySource("classpath:application.yml")
@RequiredArgsConstructor
@Slf4j
public class GptManagerImpl implements GptManager {

private final RestTemplate restTemplate;
private final AssistantsClient client;
private final Assistant assistant;

public List<Menu> sendOpenAIRequest(String messageContent) {
log.info("Processing message: {}", messageContent);

AssistantThread thread = createAssistantThread();
sendMessageToThread(thread.getId(), messageContent);

@Value("${python-server.url}")
private String server_url;
List<Menu> result;

public GptManagerImpl(RestTemplate restTemplate) {
this.restTemplate = restTemplate;
try {
result = getGptResponse(thread.getId());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return result;
}

public List<Menu> sendOpenAIRequest(String request) {
private AssistantThread createAssistantThread() {
return client.createThread(new AssistantThreadCreationOptions());
}

String url = server_url;
Map<String, String> requestBody = new HashMap<>();
requestBody.put("content", request);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(APPLICATION_JSON);
private void sendMessageToThread(String threadId, String messageContent) {
client.createMessage(threadId, new ThreadMessageOptions(USER, messageContent));
}

HttpEntity<Map<String, String>> requestHttpEntity = new HttpEntity<>(requestBody, headers);
private List<Menu> getGptResponse(String threadId) throws InterruptedException {
ThreadRun run = client.createRun(threadId, new CreateRunOptions(assistant.getId()));

// ResponseModel response= restTemplate.postForObject(url, requestHttpEntity, ResponseModel.class);
do {
run = client.getRun(run.getThreadId(), run.getId());
Thread.sleep(500);
} while (run.getStatus() == RunStatus.QUEUED || run.getStatus() == RunStatus.IN_PROGRESS);

ResponseEntity<List<Menu>> responseEntity =
restTemplate.exchange(url,HttpMethod.POST,requestHttpEntity,new ParameterizedTypeReference<List<Menu>>() {});
return extractMessagesFromResponse(client.listMessages(run.getThreadId()));
}

log.info(responseEntity.getBody().toString());
List<Menu> recommendResponse = responseEntity.getBody();
private List<Menu> extractMessagesFromResponse(PageableList<ThreadMessage> messages) {
List<Menu> result = new ArrayList<>();
ObjectMapper objectMapper = new ObjectMapper();
ThreadMessage threadMessage = messages.getData().getFirst();

if (recommendResponse.isEmpty() || recommendResponse == null) {
throw new BusinessException(ErrorMessage.RECOMMEND_NOT_FOUND);
for (MessageContent messageContent : threadMessage.getContent()) {
String jsonResponse = ((MessageTextContent) messageContent).getText().getValue();
log.info("Message content: {}", jsonResponse);
try {
jsonResponse = jsonResponse.replaceAll("```json", "").trim();
List<Menu> menu = objectMapper.readValue(jsonResponse, new TypeReference<>() {
});
result.addAll(menu);
} catch (Exception e) {
e.printStackTrace();
}
}
return recommendResponse;
log.info("result: {}" ,result.toString());
return result;
}
}

0 comments on commit 55ca1e4

Please sign in to comment.