From 41a070bf978bd75aaed2319497eb39ed60017789 Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 20 Apr 2024 11:21:04 +0200 Subject: [PATCH 1/3] chat() returns again --- NAMESPACE | 9 ++++----- R/chat.R | 47 +++++++++++++---------------------------------- R/zzz.R | 3 ++- 3 files changed, 19 insertions(+), 40 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 6aa45f4..9f356a4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,13 +1,10 @@ # Generated by roxygen2: do not edit by hand -S3method(as.data.frame,chat_response) S3method(as_messages,character) S3method(as_messages,default) S3method(as_messages,list) S3method(as_msg,character) -S3method(as_tibble,chat_response) -S3method(print,chat) -S3method(print,chat_tbl) +S3method(print,chat_tibble) export(as_messages) export(chat) export(models) @@ -20,7 +17,9 @@ import(stringr) import(tibble) importFrom(jsonlite,fromJSON) importFrom(purrr,list_flatten) +importFrom(purrr,list_rbind) +importFrom(purrr,map) importFrom(purrr,map2) importFrom(purrr,map_chr) -importFrom(purrr,map_dfr) importFrom(purrr,pluck) +importFrom(utils,tail) diff --git a/R/chat.R b/R/chat.R index 2ae70d2..22a3922 100644 --- a/R/chat.R +++ b/R/chat.R @@ -24,13 +24,22 @@ chat <- function(messages, model = "mistral-tiny", ..., error_call = current_env resp <- authenticate(req, error_call = error_call) |> req_mistral_perform(error_call = error_call) - class(resp) <- c("chat", class(resp)) - resp + data <- resp_body_json(resp) + + tbl_req <- list_rbind(map(messages, as_tibble)) + tbl_resp <- list_rbind(map(data$choices, \(choice) { + as_tibble(choice$message[c("role", "content")]) + })) + tbl <- list_rbind(list(tbl_req, tbl_resp)) + + class(tbl) <- c("chat_tibble", class(tbl)) + attr(tbl, "resp") <- resp + tbl } #' @export -print.chat <- function(x, ...) { - writeLines(resp_body_json(x)$choices[[1]]$message$content) +print.chat_tibble <- function(x, ...) { + writeLines(tail(x$content, 1L)) invisible(x) } @@ -49,33 +58,3 @@ req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, ..., erro ) ) } - -#' @export -as.data.frame.chat_response <- function(x, ...) { - req_messages <- x$request$body$data$messages - df_req <- map_dfr(req_messages, as.data.frame) - - df_resp <- as.data.frame( - resp_body_json(x)$choices[[1]]$message[c("role", "content")] - ) - - rbind(df_req, df_resp) -} - -#' @export -as_tibble.chat_response <- function(x, ...) { - tib <- as_tibble(as.data.frame(x, ...)) - class(tib) <- c("chat_tbl", class(x)) - tib -} - -#' @export -print.chat_tbl <- function(x, ...) { - n <- nrow(x) - - for (i in seq_len(n)) { - writeLines(cli::col_silver(cli::rule(x$role[i]))) - writeLines(x$content[i]) - } - invisible(x) -} diff --git a/R/zzz.R b/R/zzz.R index 9fd4446..aa92f2d 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -4,8 +4,9 @@ #' @import tibble #' @import stringr #' @import slap -#' @importFrom purrr map_dfr map_chr pluck map2 list_flatten +#' @importFrom purrr list_rbind map map_chr pluck map2 list_flatten #' @importFrom jsonlite fromJSON +#' @importFrom utils tail NULL mistral_base_url <- "https://api.mistral.ai" From bd9ba11329e786e969c3ffe20ad488e43b46c88c Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 20 Apr 2024 11:30:37 +0200 Subject: [PATCH 2/3] as_msg.chat_tibble --- NAMESPACE | 1 + R/messages.R | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 9f356a4..5d400b3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ S3method(as_messages,character) S3method(as_messages,default) S3method(as_messages,list) S3method(as_msg,character) +S3method(as_msg,chat_tibble) S3method(print,chat_tibble) export(as_messages) export(chat) diff --git a/R/messages.R b/R/messages.R index f5b2fd2..1d2c1bc 100644 --- a/R/messages.R +++ b/R/messages.R @@ -58,6 +58,16 @@ as_msg.character <- function(x, name, error_call = caller_env()) { ) } +#' @export +as_msg.chat_tibble <- function(x, name, error_call = caller_env()) { + map(seq_len(nrow(x)), \(i) { + list( + role = x$role[i], + content = x$content[i] + ) + }) +} + check_role <- function(name = "", error_call = caller_env()) { if (identical(name, "")) { name <- "user" From 71e8369654a37083149696ff0b3b6c27e28675be Mon Sep 17 00:00:00 2001 From: Romain Francois Date: Sat, 20 Apr 2024 11:52:22 +0200 Subject: [PATCH 3/3] chat(...) once again --- NAMESPACE | 3 --- R/chat.R | 21 +++++++++++---------- R/messages.R | 45 ++++++++++++++------------------------------- R/stream.R | 7 ++----- man/as_messages.Rd | 22 +++++++++++++++++----- man/chat.Rd | 20 +++++++++++--------- 6 files changed, 55 insertions(+), 63 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 5d400b3..4b2dd9f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,8 +1,5 @@ # Generated by roxygen2: do not edit by hand -S3method(as_messages,character) -S3method(as_messages,default) -S3method(as_messages,list) S3method(as_msg,character) S3method(as_msg,chat_tibble) S3method(print,chat_tibble) diff --git a/R/chat.R b/R/chat.R index 22a3922..9853460 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,24 +1,25 @@ #' Chat with the Mistral api #' -#' @param messages Messages -#' @param model which model to use. See [models()] for more information about which models are available -#' @inheritParams rlang::args_dots_empty +#' @param ... messages, see [as_messages()]. +#' @param model which model to use. See [models()] for more information about which models are available. #' @inheritParams rlang::args_error_context #' -#' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request -#' if this is a `dry_run` +#' @return A tibble with columns `role` and `content` with class `chat_tibble` #' #' @examples #' #' \dontrun{ -#' chat("Top 5 R packages") +#' res <- chat("What are the top 5 R packages ?") +#' res +#' +#' # use the result from a previous chat() to continue the +#' # conversation +#' chat(res, "Why do people love them so much ?") #' } #' #' @export -chat <- function(messages, model = "mistral-tiny", ..., error_call = current_env()) { - check_dots_empty(call = error_call) - - messages <- as_messages(messages) %!% "Can't convert {.arg messages} to a list of messages." +chat <- function(..., model = "mistral-tiny", error_call = current_env()) { + messages <- as_messages(..., error_call = error_call) req <- req_chat(messages, model = model, error_call = error_call) resp <- authenticate(req, error_call = error_call) |> diff --git a/R/messages.R b/R/messages.R index 1d2c1bc..9101193 100644 --- a/R/messages.R +++ b/R/messages.R @@ -1,42 +1,25 @@ #' Convert object into a messages list #' -#' @param messages object to convert to messages -#' @inheritParams rlang::args_dots_empty +#' @param ... objects to convert to messages. Each element can be: +#' a string or a result from [chat()]. #' @inheritParams rlang::args_error_context #' #' @examples +#' # unnamed string means user #' as_messages("hello") -#' as_messages(list("hello")) -#' as_messages(list(assistant = "hello", user = "hello")) #' +#' # explicit names +#' as_messages(assistant = "hello", user = "hello") +#' +#' \dontrun{ +#' res <- chat("hello") +#' +#' # add result from previous chat() +#' as_messages(res, "hello") +#' } #' @export -as_messages <- function(messages, ...) { - UseMethod("as_messages") -} - -#' @export -as_messages.default <- function(messages, ..., error_call = current_env()) { - cli_abort(c( - "No known method for objects of class {.cls {class(messages)}}.", - i = "Use as_messages() or as_messages()." - ), call = error_call) -} - -#' @export -as_messages.character <- function(messages, ..., error_call = current_env()) { - check_dots_empty(call = error_call) - check_scalar_string(messages, error_call = error_call) - check_unnamed_string(messages, error_call = error_call) - - list( - list(role = "user", content = messages) - ) -} - -#' @export -as_messages.list <- function(messages, ..., error_call = caller_env()) { - check_dots_empty(call = error_call) - +as_messages <- function(..., error_call = current_env()) { + messages <- list2(...) out <- list_flatten( map2(messages, names2(messages), as_msg, error_call = error_call) ) diff --git a/R/stream.R b/R/stream.R index 476c2e1..bf011eb 100644 --- a/R/stream.R +++ b/R/stream.R @@ -2,12 +2,9 @@ #' #' @rdname chat #' @export -stream <- function(messages, model = "mistral-tiny", ..., error_call = current_env()) { - check_dots_empty(call = error_call) - - messages <- as_messages(messages) +stream <- function(..., model = "mistral-tiny", error_call = current_env()) { + messages <- as_messages(..., error_call = error_call) req <- req_chat(messages, model, stream = TRUE, error_call = error_call) - resp <- req_perform_stream(req, callback = stream_callback, round = "line", buffer_kb = 0.01) invisible(resp) diff --git a/man/as_messages.Rd b/man/as_messages.Rd index 5af8f19..647215f 100644 --- a/man/as_messages.Rd +++ b/man/as_messages.Rd @@ -4,19 +4,31 @@ \alias{as_messages} \title{Convert object into a messages list} \usage{ -as_messages(messages, ...) +as_messages(..., error_call = current_env()) } \arguments{ -\item{messages}{object to convert to messages} +\item{...}{objects to convert to messages. Each element can be: +a string or a result from \code{\link[=chat]{chat()}}.} -\item{...}{These dots are for future extensions and must be empty.} +\item{error_call}{The execution environment of a currently +running function, e.g. \code{caller_env()}. The function will be +mentioned in error messages as the source of the error. See the +\code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} } \description{ Convert object into a messages list } \examples{ +# unnamed string means user as_messages("hello") -as_messages(list("hello")) -as_messages(list(assistant = "hello", user = "hello")) +# explicit names +as_messages(assistant = "hello", user = "hello") + +\dontrun{ + res <- chat("hello") + + # add result from previous chat() + as_messages(res, "hello") +} } diff --git a/man/chat.Rd b/man/chat.Rd index ea0d8e0..4708ec3 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -5,16 +5,14 @@ \alias{stream} \title{Chat with the Mistral api} \usage{ -chat(messages, model = "mistral-tiny", ..., error_call = current_env()) +chat(..., model = "mistral-tiny", error_call = current_env()) -stream(messages, model = "mistral-tiny", ..., error_call = current_env()) +stream(..., model = "mistral-tiny", error_call = current_env()) } \arguments{ -\item{messages}{Messages} +\item{...}{messages, see \code{\link[=as_messages]{as_messages()}}.} -\item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} - -\item{...}{These dots are for future extensions and must be empty.} +\item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available.} \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be @@ -22,8 +20,7 @@ mentioned in error messages as the source of the error. See the \code{call} argument of \code{\link[rlang:abort]{abort()}} for more information.} } \value{ -A tibble with columns \code{role} and \code{content} with class \code{chat_tibble} or a request -if this is a \code{dry_run} +A tibble with columns \code{role} and \code{content} with class \code{chat_tibble} } \description{ Chat with the Mistral api @@ -33,7 +30,12 @@ stream \examples{ \dontrun{ - chat("Top 5 R packages") + res <- chat("What are the top 5 R packages ?") + res + + # use the result from a previous chat() to continue the + # conversation + chat(res, "Why do people love them so much ?") } }