Skip to content

Commit

Permalink
Revert "Remove extra information from jinja."
Browse files Browse the repository at this point in the history
This reverts commit 15e8fba.
  • Loading branch information
manyoso committed Nov 4, 2024
1 parent 15e8fba commit 50ea10e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
61 changes: 59 additions & 2 deletions gpt4all-chat/src/jinja_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@

using namespace std::literals::string_view_literals;


JinjaResultInfo::~JinjaResultInfo() = default;

const JinjaFieldMap<ResultInfo> JinjaResultInfo::s_fields = {
{ "collection", [](auto &s) { return s.collection.toStdString(); } },
{ "path", [](auto &s) { return s.path .toStdString(); } },
{ "file", [](auto &s) { return s.file .toStdString(); } },
{ "title", [](auto &s) { return s.title .toStdString(); } },
{ "author", [](auto &s) { return s.author .toStdString(); } },
{ "date", [](auto &s) { return s.date .toStdString(); } },
{ "text", [](auto &s) { return s.text .toStdString(); } },
{ "page", [](auto &s) { return s.page; } },
{ "fileUri", [](auto &s) { return s.fileUri() .toStdString(); } },
};

JinjaPromptAttachment::~JinjaPromptAttachment() = default;

const JinjaFieldMap<PromptAttachment> JinjaPromptAttachment::s_fields = {
{ "url", [](auto &s) { return s.url.toString() .toStdString(); } },
{ "file", [](auto &s) { return s.file() .toStdString(); } },
{ "processedContent", [](auto &s) { return s.processedContent().toStdString(); } },
};

std::vector<std::string> JinjaMessage::GetKeys() const
{
std::vector<std::string> result;
Expand All @@ -26,15 +49,37 @@ auto JinjaMessage::keys() const -> const std::unordered_set<std::string_view> &
{
static const std::unordered_set<std::string_view> baseKeys
{ "role", "content" };
return baseKeys;
static const std::unordered_set<std::string_view> userKeys
{ "role", "content", "sources", "prompt_attachments" };
switch (m_item->type()) {
using enum ChatItem::Type;
case System:
case Response:
return baseKeys;
case Prompt:
return userKeys;
}
Q_UNREACHABLE();
}

bool operator==(const JinjaMessage &a, const JinjaMessage &b)
{
if (a.m_item == b.m_item)
return true;
const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item);
return ia.type() == ib.type() && ia.value == ib.value;
auto type = ia.type();
if (type != ib.type() || ia.value != ib.value)
return false;

switch (type) {
using enum ChatItem::Type;
case System:
case Response:
return true;
case Prompt:
return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments;
}
Q_UNREACHABLE();
}

const JinjaFieldMap<ChatItem> JinjaMessage::s_fields = {
Expand All @@ -48,4 +93,16 @@ const JinjaFieldMap<ChatItem> JinjaMessage::s_fields = {
Q_UNREACHABLE();
} },
{ "content", [](auto &i) { return i.value.toStdString(); } },
{ "sources", [](auto &i) {
auto sources = i.sources | views::transform([](auto &r) {
return jinja2::GenericMap([map = std::make_shared<JinjaResultInfo>(r)] { return map.get(); });
});
return jinja2::ValuesList(sources.begin(), sources.end());
} },
{ "prompt_attachments", [](auto &i) {
auto attachments = i.promptAttachments | views::transform([](auto &pa) {
return jinja2::GenericMap([map = std::make_shared<JinjaPromptAttachment>(pa)] { return map.get(); });
});
return jinja2::ValuesList(attachments.begin(), attachments.end());
} },
};
20 changes: 20 additions & 0 deletions gpt4all-chat/src/jinja_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

namespace views = std::views;


template <typename T>
using JinjaFieldMap = std::unordered_map<std::string_view, std::function<jinja2::Value (const T &)>>;

Expand Down Expand Up @@ -62,6 +63,25 @@ class JinjaResultInfo : public JinjaHelper<JinjaResultInfo> {
friend class JinjaHelper<JinjaResultInfo>;
};

class JinjaPromptAttachment : public JinjaHelper<JinjaPromptAttachment> {
public:
explicit JinjaPromptAttachment(const PromptAttachment &attachment) noexcept
: m_attachment(&attachment) {}

~JinjaPromptAttachment() override;

const PromptAttachment &value() const { return *m_attachment; }

friend bool operator==(const JinjaPromptAttachment &a, const JinjaPromptAttachment &b)
{ return a.m_attachment == b.m_attachment || *a.m_attachment == *b.m_attachment; }

private:
static const JinjaFieldMap<PromptAttachment> s_fields;
const PromptAttachment *m_attachment;

friend class JinjaHelper<JinjaPromptAttachment>;
};

class JinjaMessage : public JinjaHelper<JinjaMessage> {
public:
explicit JinjaMessage(const ChatItem &item) noexcept
Expand Down

0 comments on commit 50ea10e

Please sign in to comment.