From af39608d85829ab9414075eac86252054f6c0226 Mon Sep 17 00:00:00 2001 From: Wang Yiwen <121547057+yiwen101@users.noreply.github.com> Date: Mon, 12 Aug 2024 18:18:42 +0800 Subject: [PATCH 1/2] Add payload size limit --- lib/cadet_web/controllers/chat_controller.ex | 79 +++++++++++-------- lib/cadet_web/views/chat_view.ex | 8 +- .../controllers/chat_controller_test.exs | 18 +++++ 3 files changed, 70 insertions(+), 35 deletions(-) diff --git a/lib/cadet_web/controllers/chat_controller.ex b/lib/cadet_web/controllers/chat_controller.ex index e39271ea4..6aa49aca1 100644 --- a/lib/cadet_web/controllers/chat_controller.ex +++ b/lib/cadet_web/controllers/chat_controller.ex @@ -6,6 +6,7 @@ defmodule CadetWeb.ChatController do use PhoenixSwagger alias Cadet.Chatbot.{Conversation, LlmConversations} + @max_content_size 1000 def init_chat(conn, %{"section" => section, "initialContext" => initialContext}) do user = conn.assigns.current_user @@ -21,7 +22,8 @@ defmodule CadetWeb.ChatController do "conversation_init.json", %{ conversation_id: conversation.id, - last_message: conversation.messages |> List.last() + last_message: conversation.messages |> List.last(), + max_content_size: @max_content_size } ) @@ -51,45 +53,54 @@ defmodule CadetWeb.ChatController do response(200, "OK") response(400, "Missing or invalid parameter(s)") response(401, "Unauthorized") + response(421, "Message exceeds the maximum allowed length") response(500, "When OpenAI API returns an error") end def chat(conn, %{"conversationId" => conversation_id, "message" => user_message}) do - user = conn.assigns.current_user + if String.length(user_message) > @max_content_size do + send_resp( + conn, + :unprocessable_entity, + "Message exceeds the maximum allowed length of #{@max_content_size}" + ) + else + user = conn.assigns.current_user + + with {:ok, conversation} <- + LlmConversations.get_conversation_for_user(user.id, conversation_id), + {:ok, updated_conversation} <- + LlmConversations.add_message(conversation, "user", user_message), + payload <- generate_payload(updated_conversation) do + case OpenAI.chat_completion(model: "gpt-4", messages: payload) do + {:ok, result_map} -> + choices = Map.get(result_map, :choices, []) + bot_message = Enum.at(choices, 0)["message"]["content"] + + case LlmConversations.add_message(updated_conversation, "assistant", bot_message) do + {:ok, _} -> + render(conn, "conversation.json", %{ + conversation_id: conversation_id, + response: bot_message + }) + + {:error, error_message} -> + send_resp(conn, 500, error_message) + end + + {:error, reason} -> + error_message = reason["error"]["message"] + IO.puts("Error message from openAI response: #{error_message}") + LlmConversations.add_error_message(updated_conversation) + send_resp(conn, 500, error_message) + end + else + {:error, {:not_found, error_message}} -> + send_resp(conn, :not_found, error_message) - with {:ok, conversation} <- - LlmConversations.get_conversation_for_user(user.id, conversation_id), - {:ok, updated_conversation} <- - LlmConversations.add_message(conversation, "user", user_message), - payload <- generate_payload(updated_conversation) do - case OpenAI.chat_completion(model: "gpt-4", messages: payload) do - {:ok, result_map} -> - choices = Map.get(result_map, :choices, []) - bot_message = Enum.at(choices, 0)["message"]["content"] - - case LlmConversations.add_message(updated_conversation, "assistant", bot_message) do - {:ok, _} -> - render(conn, "conversation.json", %{ - conversation_id: conversation_id, - response: bot_message - }) - - {:error, error_message} -> - send_resp(conn, 500, error_message) - end - - {:error, reason} -> - error_message = reason["error"]["message"] - IO.puts("Error message from openAI response: #{error_message}") - LlmConversations.add_error_message(updated_conversation) + {:error, error_message} -> send_resp(conn, 500, error_message) end - else - {:error, {:not_found, error_message}} -> - send_resp(conn, :not_found, error_message) - - {:error, error_message} -> - send_resp(conn, 500, error_message) end end @@ -107,4 +118,6 @@ defmodule CadetWeb.ChatController do conversation.prepend_context ++ messages_payload end + + def max_content_length, do: @max_content_size end diff --git a/lib/cadet_web/views/chat_view.ex b/lib/cadet_web/views/chat_view.ex index b7db0bd2b..8dc72727a 100644 --- a/lib/cadet_web/views/chat_view.ex +++ b/lib/cadet_web/views/chat_view.ex @@ -1,8 +1,12 @@ defmodule CadetWeb.ChatView do use CadetWeb, :view - def render("conversation_init.json", %{conversation_id: id, last_message: last}) do - %{conversationId: id, response: last} + def render("conversation_init.json", %{ + conversation_id: id, + last_message: last, + max_content_size: size + }) do + %{conversationId: id, response: last, maxContentSize: size} end def render("conversation.json", %{conversation_id: id, response: response}) do diff --git a/test/cadet_web/controllers/chat_controller_test.exs b/test/cadet_web/controllers/chat_controller_test.exs index e5f63f164..9fd6cf996 100644 --- a/test/cadet_web/controllers/chat_controller_test.exs +++ b/test/cadet_web/controllers/chat_controller_test.exs @@ -76,6 +76,24 @@ defmodule CadetWeb.ChatControllerTest do end end + @tag authenticate: :student + @tag requires_setup: true + test "The content length is too long", + %{conn: conn, conversation_id: conversation_id} do + assert conversation_id != nil + max_message_length = ChatController.max_content_length() + message_exceed_length = String.duplicate("a", max_message_length + 1) + + conn = + post(conn, "/v2/chats/#{conversation_id}/message", %{ + "conversation_id" => conversation_id, + "message" => "#{message_exceed_length}" + }) + + assert response(conn, 422) == + "Message exceeds the maximum allowed length of #{max_message_length}" + end + @tag authenticate: :student test "invalid conversation id", %{conn: conn} do conversation_id = "-1" From cbd76332f7deafeaefbb5bac83e325d208c780a0 Mon Sep 17 00:00:00 2001 From: Wang Yiwen <121547057+yiwen101@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:19:06 +0800 Subject: [PATCH 2/2] Improve code quality --- lib/cadet_web/controllers/chat_controller.ex | 82 +++++++++---------- .../controllers/chat_controller_test.exs | 19 ++++- 2 files changed, 59 insertions(+), 42 deletions(-) diff --git a/lib/cadet_web/controllers/chat_controller.ex b/lib/cadet_web/controllers/chat_controller.ex index 6aa49aca1..030c7cd3a 100644 --- a/lib/cadet_web/controllers/chat_controller.ex +++ b/lib/cadet_web/controllers/chat_controller.ex @@ -53,54 +53,54 @@ defmodule CadetWeb.ChatController do response(200, "OK") response(400, "Missing or invalid parameter(s)") response(401, "Unauthorized") - response(421, "Message exceeds the maximum allowed length") + response(422, "Message exceeds the maximum allowed length") response(500, "When OpenAI API returns an error") end def chat(conn, %{"conversationId" => conversation_id, "message" => user_message}) do - if String.length(user_message) > @max_content_size do - send_resp( - conn, - :unprocessable_entity, - "Message exceeds the maximum allowed length of #{@max_content_size}" - ) - else - user = conn.assigns.current_user - - with {:ok, conversation} <- - LlmConversations.get_conversation_for_user(user.id, conversation_id), - {:ok, updated_conversation} <- - LlmConversations.add_message(conversation, "user", user_message), - payload <- generate_payload(updated_conversation) do - case OpenAI.chat_completion(model: "gpt-4", messages: payload) do - {:ok, result_map} -> - choices = Map.get(result_map, :choices, []) - bot_message = Enum.at(choices, 0)["message"]["content"] - - case LlmConversations.add_message(updated_conversation, "assistant", bot_message) do - {:ok, _} -> - render(conn, "conversation.json", %{ - conversation_id: conversation_id, - response: bot_message - }) - - {:error, error_message} -> - send_resp(conn, 500, error_message) - end - - {:error, reason} -> - error_message = reason["error"]["message"] - IO.puts("Error message from openAI response: #{error_message}") - LlmConversations.add_error_message(updated_conversation) - send_resp(conn, 500, error_message) - end - else - {:error, {:not_found, error_message}} -> - send_resp(conn, :not_found, error_message) + user = conn.assigns.current_user - {:error, error_message} -> + with true <- String.length(user_message) <= @max_content_size || {:error, :message_too_long}, + {:ok, conversation} <- + LlmConversations.get_conversation_for_user(user.id, conversation_id), + {:ok, updated_conversation} <- + LlmConversations.add_message(conversation, "user", user_message), + payload <- generate_payload(updated_conversation) do + case OpenAI.chat_completion(model: "gpt-4", messages: payload) do + {:ok, result_map} -> + choices = Map.get(result_map, :choices, []) + bot_message = Enum.at(choices, 0)["message"]["content"] + + case LlmConversations.add_message(updated_conversation, "assistant", bot_message) do + {:ok, _} -> + render(conn, "conversation.json", %{ + conversation_id: conversation_id, + response: bot_message + }) + + {:error, error_message} -> + send_resp(conn, 500, error_message) + end + + {:error, reason} -> + error_message = reason["error"]["message"] + IO.puts("Error message from openAI response: #{error_message}") + LlmConversations.add_error_message(updated_conversation) send_resp(conn, 500, error_message) end + else + {:error, :message_too_long} -> + send_resp( + conn, + :unprocessable_entity, + "Message exceeds the maximum allowed length of #{@max_content_size}" + ) + + {:error, {:not_found, error_message}} -> + send_resp(conn, :not_found, error_message) + + {:error, error_message} -> + send_resp(conn, 500, error_message) end end diff --git a/test/cadet_web/controllers/chat_controller_test.exs b/test/cadet_web/controllers/chat_controller_test.exs index 9fd6cf996..b8c0c6808 100644 --- a/test/cadet_web/controllers/chat_controller_test.exs +++ b/test/cadet_web/controllers/chat_controller_test.exs @@ -90,10 +90,27 @@ defmodule CadetWeb.ChatControllerTest do "message" => "#{message_exceed_length}" }) - assert response(conn, 422) == + assert response(conn, :unprocessable_entity) == "Message exceeds the maximum allowed length of #{max_message_length}" end + @tag authenticate: :student + @tag requires_setup: true + test "The content length less than the maximum allowed length but conversation belongs to another user", + %{conn: conn, conversation_id: conversation_id} do + assert conversation_id != nil + max_message_length = ChatController.max_content_length() + message_exceed_length = String.duplicate("a", max_message_length) + + conn = + post(conn, "/v2/chats/#{conversation_id}/message", %{ + "conversation_id" => conversation_id, + "message" => "#{message_exceed_length}" + }) + + assert response(conn, :not_found) == "Conversation not found" + end + @tag authenticate: :student test "invalid conversation id", %{conn: conn} do conversation_id = "-1"