Skip to content

Commit

Permalink
feat: Add prefetch for paginator
Browse files Browse the repository at this point in the history
Signed-off-by: zychen5186 <brianchen5197@gmail.com>
  • Loading branch information
zychen5186 committed May 1, 2024
1 parent 88a2f50 commit 954c84c
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 78 deletions.
10 changes: 4 additions & 6 deletions flytectl/cmd/get/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,15 @@ func getExecutionFunc(ctx context.Context, args []string, cmdCtx cmdCore.Command
return adminPrinter.Print(config.GetConfig().MustOutputFormat(), executionColumns,
ExecutionToProtoMessages(executions)...)
}
if config.GetConfig().Interactive {
err := bubbletea.Paginator(executionColumns, getCallBack(ctx, cmdCtx), execution.DefaultConfig.Filter)
return err
}
executionList, err := cmdCtx.AdminFetcherExt().ListExecution(ctx, config.GetConfig().Project, config.GetConfig().Domain, execution.DefaultConfig.Filter)
if err != nil {
return err
}
logger.Infof(ctx, "Retrieved %v executions", len(executionList.Executions))

if config.GetConfig().Interactive {
bubbletea.Paginator(executionColumns, getCallBack(ctx, cmdCtx))
return nil
}

return adminPrinter.Print(config.GetConfig().MustOutputFormat(), executionColumns,
ExecutionToProtoMessages(executionList.Executions)...)
}
111 changes: 101 additions & 10 deletions flytectl/pkg/bubbletea/bubbletea_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,52 @@ package bubbletea
import (
"fmt"
"log"
"math"
"strings"

"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/paginator"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/lipgloss"
"github.com/flyteorg/flytectl/pkg/filters"
"github.com/flyteorg/flytectl/pkg/printer"
"github.com/golang/protobuf/proto"

tea "github.com/charmbracelet/bubbletea"
)

var (
spin = false
// Avoid fetching multiple times while still fetching
fetchingBackward = false
fetchingForward = false
)

type pageModel struct {
items []proto.Message
items *[]proto.Message
paginator paginator.Model
spinner spinner.Model
}

func newModel(initMsg []proto.Message) pageModel {
p := paginator.New()
p.PerPage = msgPerPage
p.SetTotalPages(len(initMsg))
p.Page = int(filter.Page) - 1
p.SetTotalPages(getLocalLastPage())

s := spinner.New()
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("56"))
s.Spinner = spinner.Points

return pageModel{
paginator: p,
items: initMsg,
spinner: s,
items: &initMsg,
}
}

func (m pageModel) Init() tea.Cmd {
return nil
return m.spinner.Tick
}

func (m pageModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
Expand All @@ -40,35 +59,107 @@ func (m pageModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case "q", "esc", "ctrl+c":
return m, tea.Quit
}
switch {
case key.Matches(msg, m.paginator.KeyMap.PrevPage):
// If previous page will be out of the range of the first batch, don't update
if m.paginator.Page == firstBatchIndex*pagePerBatch {
return m, nil
}
}
case spinner.TickMsg:
m.spinner, cmd = m.spinner.Update(msg)
return m, cmd
case newDataMsg:
if msg.fetchDirection == forward {
// Update if current page is in the range of the last batch
// i.e. if user not in last batch when finished fetching, don't update
if m.paginator.Page/pagePerBatch >= lastBatchIndex {
*m.items = append(*m.items, msg.newItems...)
lastBatchIndex++
if lastBatchIndex-firstBatchIndex >= localBatchLimit {
*m.items = (*m.items)[batchLen[firstBatchIndex]:]
firstBatchIndex++
}
}
fetchingForward = false
} else {
// Update if current page is in the range of the first batch
// i.e. if user not in first batch when finished fetching, don't update
if m.paginator.Page/pagePerBatch <= firstBatchIndex {
*m.items = append(msg.newItems, *m.items...)
firstBatchIndex--
if lastBatchIndex-firstBatchIndex >= localBatchLimit {
*m.items = (*m.items)[:len(*m.items)-batchLen[lastBatchIndex]]
lastBatchIndex--
}
}
fetchingBackward = false
}
m.paginator.SetTotalPages(getLocalLastPage())
return m, nil
}
m.paginator, cmd = m.paginator.Update(msg)
preFetchBatch(&m)

m.paginator, _ = m.paginator.Update(msg)
switch msg := msg.(type) {
case tea.KeyMsg:
switch {
case key.Matches(msg, m.paginator.KeyMap.NextPage):
if (m.paginator.Page >= (lastBatchIndex+1)*pagePerBatch-prefetchThreshold) && !fetchingForward {
// If no more data, don't fetch again (won't show spinner)
value, ok := batchLen[lastBatchIndex+1]
if !ok || value != 0 {
fetchingForward = true
cmd = fetchDataCmd(lastBatchIndex+1, forward)
}
}
case key.Matches(msg, m.paginator.KeyMap.PrevPage):
if (m.paginator.Page <= firstBatchIndex*pagePerBatch+prefetchThreshold) && (firstBatchIndex > 0) && !fetchingBackward {
fetchingBackward = true
cmd = fetchDataCmd(firstBatchIndex-1, backward)
}
}
}

return m, cmd
}

func (m pageModel) View() string {
var b strings.Builder
table, err := getTable(&m)
if err != nil {
return ""
return "Error rendering table"
}
b.WriteString(table)
b.WriteString(fmt.Sprintf(" PAGE - %d\n", m.paginator.Page+1))
b.WriteString(fmt.Sprintf(" PAGE - %d ", m.paginator.Page+1))
if spin {
b.WriteString(fmt.Sprintf("%s%s", m.spinner.View(), " Loading new pages..."))
}
b.WriteString("\n\n h/l ←/→ page • q: quit\n")

return b.String()
}

func Paginator(_listHeader []printer.Column, _callback DataCallback) {
func Paginator(_listHeader []printer.Column, _callback DataCallback, _filter filters.Filters) error {
listHeader = _listHeader
callback = _callback
filter = _filter
filter.Page = int32(_max(int(filter.Page), 1))
firstBatchIndex = (int(filter.Page) - 1) / pagePerBatch
lastBatchIndex = firstBatchIndex

var msg []proto.Message
for i := firstBatchIndex; i < lastBatchIndex+1; i++ {
msg = append(msg, getMessageList(i)...)
newMessages := getMessageList(i)
if int(filter.Page)-(firstBatchIndex*pagePerBatch) > int(math.Ceil(float64(len(newMessages))/msgPerPage)) {
return fmt.Errorf("the specified page has no data, please enter a valid page number")
}
msg = append(msg, newMessages...)
}

p := tea.NewProgram(newModel(msg))
if _, err := p.Run(); err != nil {
log.Fatal(err)
}

return nil
}
138 changes: 76 additions & 62 deletions flytectl/pkg/bubbletea/bubbletea_pagination_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"encoding/json"
"fmt"
"strings"
"sync"

tea "github.com/charmbracelet/bubbletea"
"github.com/flyteorg/flytectl/pkg/filters"
"github.com/flyteorg/flytectl/pkg/printer"

Expand All @@ -15,32 +17,31 @@ import (

type DataCallback func(filter filters.Filters) []proto.Message

type PrintableProto struct{ proto.Message }
type printTableProto struct{ proto.Message }

const (
msgPerBatch = 100 // Please set msgPerBatch as a multiple of msgPerPage
msgPerPage = 10
pagePerBatch = msgPerBatch / msgPerPage
msgPerBatch = 100 // Please set msgPerBatch as a multiple of msgPerPage
msgPerPage = 10
pagePerBatch = msgPerBatch / msgPerPage
prefetchThreshold = pagePerBatch - 1
localBatchLimit = 10 // Please set localBatchLimit at least 2
)

var (
// Used for indexing local stored rows
localPageIndex int
// Recording batch index fetched from admin
firstBatchIndex int32 = 1
lastBatchIndex int32 = 10
batchLen = make(map[int32]int)
// Callback function used to fetch data from the module that called bubbletea pagination.
callback DataCallback
// The header of the table
callback DataCallback
listHeader []printer.Column

marshaller = jsonpb.Marshaler{
Indent: "\t",
}
filter filters.Filters
// Record the index of the first and last batch that is in cache
firstBatchIndex int
lastBatchIndex int
batchLen = make(map[int]int)
// Avoid fetching back and forward at the same time
mutex sync.Mutex
)

func (p PrintableProto) MarshalJSON() ([]byte, error) {
func (p printTableProto) MarshalJSON() ([]byte, error) {
marshaller := jsonpb.Marshaler{Indent: "\t"}
buf := new(bytes.Buffer)
err := marshaller.Marshal(buf, p.Message)
if err != nil {
Expand All @@ -49,28 +50,35 @@ func (p PrintableProto) MarshalJSON() ([]byte, error) {
return buf.Bytes(), nil
}

func min(a, b int) int {
func _max(a, b int) int {
if a > b {
return a
}
return b
}

func _min(a, b int) int {
if a < b {
return a
}
return b
}

func getSliceBounds(idx int, length int) (start int, end int) {
start = idx * msgPerPage
end = min(idx*msgPerPage+msgPerPage, length)
func getSliceBounds(m *pageModel) (start int, end int) {
start = (m.paginator.Page - firstBatchIndex*pagePerBatch) * msgPerPage
end = _min(start+msgPerPage, len(*m.items))
return start, end
}

func getTable(m *pageModel) (string, error) {
start, end := getSliceBounds(localPageIndex, len(m.items))
curShowMessage := m.items[start:end]
printableMessages := make([]*PrintableProto, 0, len(curShowMessage))
start, end := getSliceBounds(m)
curShowMessage := (*m.items)[start:end]
printTableMessages := make([]*printTableProto, 0, len(curShowMessage))
for _, m := range curShowMessage {
printableMessages = append(printableMessages, &PrintableProto{Message: m})
printTableMessages = append(printTableMessages, &printTableProto{Message: m})
}

jsonRows, err := json.Marshal(printableMessages)
jsonRows, err := json.Marshal(printTableMessages)
if err != nil {
return "", fmt.Errorf("failed to marshal proto messages")
}
Expand All @@ -84,53 +92,59 @@ func getTable(m *pageModel) (string, error) {
return buf.String(), nil
}

func getMessageList(batchIndex int32) []proto.Message {
func getMessageList(batchIndex int) []proto.Message {
mutex.Lock()
spin = true
defer func() {
spin = false
mutex.Unlock()
}()

msg := callback(filters.Filters{
Limit: msgPerBatch,
Page: batchIndex,
SortBy: "created_at",
Asc: false,
Page: int32(batchIndex + 1),
SortBy: filter.SortBy,
Asc: filter.Asc,
})

batchLen[batchIndex] = len(msg)

return msg
}

func countTotalPages() int {
sum := 0
for _, l := range batchLen {
sum += l
}
return sum
type direction int

const (
forward direction = iota
backward
)

type newDataMsg struct {
newItems []proto.Message
batchIndex int
fetchDirection direction
}

// Only (lastBatchIndex-firstBatchIndex)*msgPerBatch of rows are stored in local memory.
// When user tries to get rows out of this range, this function will be triggered.
func preFetchBatch(m *pageModel) {
localPageIndex = m.paginator.Page - int(firstBatchIndex-1)*pagePerBatch

// Triggers when user is at the last local page
if localPageIndex+1 == len(m.items)/msgPerPage {
newMessages := getMessageList(lastBatchIndex + 1)
m.paginator.SetTotalPages(countTotalPages())
if len(newMessages) != 0 {
lastBatchIndex++
m.items = append(m.items, newMessages...)
m.items = m.items[batchLen[firstBatchIndex]:] // delete the msgs in the "firstBatchIndex" batch
localPageIndex -= batchLen[firstBatchIndex] / msgPerPage
firstBatchIndex++
func fetchDataCmd(batchIndex int, fetchDirection direction) tea.Cmd {
return func() tea.Msg {
msg := newDataMsg{
newItems: getMessageList(batchIndex),
batchIndex: batchIndex,
fetchDirection: fetchDirection,
}
return
return msg
}
// Triggers when user is at the first local page
if localPageIndex == 0 && firstBatchIndex > 1 {
newMessages := getMessageList(firstBatchIndex - 1)
m.paginator.SetTotalPages(countTotalPages())
firstBatchIndex--
m.items = append(newMessages, m.items...)
m.items = m.items[:len(m.items)-batchLen[lastBatchIndex]] // delete the msgs in the "lastBatchIndex" batch
localPageIndex += batchLen[firstBatchIndex] / msgPerPage
lastBatchIndex--
return
}

func getLocalLastPage() int {
sum := 0
for i := 0; i < lastBatchIndex+1; i++ {
length, ok := batchLen[i]
if ok {
sum += length
} else {
sum += msgPerBatch
}
}
return sum
}

0 comments on commit 954c84c

Please sign in to comment.