Skip to content

Commit

Permalink
Merge pull request #29 from tadascience/chat_tibble
Browse files Browse the repository at this point in the history
`chat()` takes `...` and return a `<chat_tibble>` once again
  • Loading branch information
romainfrancois authored Apr 20, 2024
2 parents 6c0329a + 71e8369 commit ef8f30f
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 103 deletions.
13 changes: 5 additions & 8 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(as.data.frame,chat_response)
S3method(as_messages,character)
S3method(as_messages,default)
S3method(as_messages,list)
S3method(as_msg,character)
S3method(as_tibble,chat_response)
S3method(print,chat)
S3method(print,chat_tbl)
S3method(as_msg,chat_tibble)
S3method(print,chat_tibble)
export(as_messages)
export(chat)
export(models)
Expand All @@ -20,7 +15,9 @@ import(stringr)
import(tibble)
importFrom(jsonlite,fromJSON)
importFrom(purrr,list_flatten)
importFrom(purrr,list_rbind)
importFrom(purrr,map)
importFrom(purrr,map2)
importFrom(purrr,map_chr)
importFrom(purrr,map_dfr)
importFrom(purrr,pluck)
importFrom(utils,tail)
68 changes: 24 additions & 44 deletions R/chat.R
Original file line number Diff line number Diff line change
@@ -1,36 +1,46 @@
#' Chat with the Mistral api
#'
#' @param messages Messages
#' @param model which model to use. See [models()] for more information about which models are available
#' @inheritParams rlang::args_dots_empty
#' @param ... messages, see [as_messages()].
#' @param model which model to use. See [models()] for more information about which models are available.
#' @inheritParams rlang::args_error_context
#'
#' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request
#' if this is a `dry_run`
#' @return A tibble with columns `role` and `content` with class `chat_tibble`
#'
#' @examples
#'
#' \dontrun{
#' chat("Top 5 R packages")
#' res <- chat("What are the top 5 R packages ?")
#' res
#'
#' # use the result from a previous chat() to continue the
#' # conversation
#' chat(res, "Why do people love them so much ?")
#' }
#'
#' @export
chat <- function(messages, model = "mistral-tiny", ..., error_call = current_env()) {
check_dots_empty(call = error_call)

messages <- as_messages(messages) %!% "Can't convert {.arg messages} to a list of messages."
chat <- function(..., model = "mistral-tiny", error_call = current_env()) {
messages <- as_messages(..., error_call = error_call)

req <- req_chat(messages, model = model, error_call = error_call)
resp <- authenticate(req, error_call = error_call) |>
req_mistral_perform(error_call = error_call)

class(resp) <- c("chat", class(resp))
resp
data <- resp_body_json(resp)

tbl_req <- list_rbind(map(messages, as_tibble))
tbl_resp <- list_rbind(map(data$choices, \(choice) {
as_tibble(choice$message[c("role", "content")])
}))
tbl <- list_rbind(list(tbl_req, tbl_resp))

class(tbl) <- c("chat_tibble", class(tbl))
attr(tbl, "resp") <- resp
tbl
}

#' @export
print.chat <- function(x, ...) {
writeLines(resp_body_json(x)$choices[[1]]$message$content)
print.chat_tibble <- function(x, ...) {
writeLines(tail(x$content, 1L))
invisible(x)
}

Expand All @@ -49,33 +59,3 @@ req_chat <- function(messages, model = "mistral-tiny", stream = FALSE, ..., erro
)
)
}

#' @export
as.data.frame.chat_response <- function(x, ...) {
req_messages <- x$request$body$data$messages
df_req <- map_dfr(req_messages, as.data.frame)

df_resp <- as.data.frame(
resp_body_json(x)$choices[[1]]$message[c("role", "content")]
)

rbind(df_req, df_resp)
}

#' @export
as_tibble.chat_response <- function(x, ...) {
tib <- as_tibble(as.data.frame(x, ...))
class(tib) <- c("chat_tbl", class(x))
tib
}

#' @export
print.chat_tbl <- function(x, ...) {
n <- nrow(x)

for (i in seq_len(n)) {
writeLines(cli::col_silver(cli::rule(x$role[i])))
writeLines(x$content[i])
}
invisible(x)
}
55 changes: 24 additions & 31 deletions R/messages.R
Original file line number Diff line number Diff line change
@@ -1,42 +1,25 @@
#' Convert object into a messages list
#'
#' @param messages object to convert to messages
#' @inheritParams rlang::args_dots_empty
#' @param ... objects to convert to messages. Each element can be:
#' a string or a result from [chat()].
#' @inheritParams rlang::args_error_context
#'
#' @examples
#' # unnamed string means user
#' as_messages("hello")
#' as_messages(list("hello"))
#' as_messages(list(assistant = "hello", user = "hello"))
#'
#' # explicit names
#' as_messages(assistant = "hello", user = "hello")
#'
#' \dontrun{
#' res <- chat("hello")
#'
#' # add result from previous chat()
#' as_messages(res, "hello")
#' }
#' @export
as_messages <- function(messages, ...) {
UseMethod("as_messages")
}

#' @export
as_messages.default <- function(messages, ..., error_call = current_env()) {
cli_abort(c(
"No known method for objects of class {.cls {class(messages)}}.",
i = "Use as_messages(<character>) or as_messages(<list>)."
), call = error_call)
}

#' @export
as_messages.character <- function(messages, ..., error_call = current_env()) {
check_dots_empty(call = error_call)
check_scalar_string(messages, error_call = error_call)
check_unnamed_string(messages, error_call = error_call)

list(
list(role = "user", content = messages)
)
}

#' @export
as_messages.list <- function(messages, ..., error_call = caller_env()) {
check_dots_empty(call = error_call)

as_messages <- function(..., error_call = current_env()) {
messages <- list2(...)
out <- list_flatten(
map2(messages, names2(messages), as_msg, error_call = error_call)
)
Expand All @@ -58,6 +41,16 @@ as_msg.character <- function(x, name, error_call = caller_env()) {
)
}

#' @export
as_msg.chat_tibble <- function(x, name, error_call = caller_env()) {
map(seq_len(nrow(x)), \(i) {
list(
role = x$role[i],
content = x$content[i]
)
})
}

check_role <- function(name = "", error_call = caller_env()) {
if (identical(name, "")) {
name <- "user"
Expand Down
7 changes: 2 additions & 5 deletions R/stream.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
#'
#' @rdname chat
#' @export
stream <- function(messages, model = "mistral-tiny", ..., error_call = current_env()) {
check_dots_empty(call = error_call)

messages <- as_messages(messages)
stream <- function(..., model = "mistral-tiny", error_call = current_env()) {
messages <- as_messages(..., error_call = error_call)
req <- req_chat(messages, model, stream = TRUE, error_call = error_call)

resp <- req_perform_stream(req, callback = stream_callback, round = "line", buffer_kb = 0.01)

invisible(resp)
Expand Down
3 changes: 2 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
#' @import tibble
#' @import stringr
#' @import slap
#' @importFrom purrr map_dfr map_chr pluck map2 list_flatten
#' @importFrom purrr list_rbind map map_chr pluck map2 list_flatten
#' @importFrom jsonlite fromJSON
#' @importFrom utils tail
NULL

mistral_base_url <- "https://api.mistral.ai"
Expand Down
22 changes: 17 additions & 5 deletions man/as_messages.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 11 additions & 9 deletions man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ef8f30f

Please sign in to comment.