Skip to content

Commit

Permalink
support dump and restore binary backup (gorse-io#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored and Anonymous committed Nov 12, 2024
1 parent e32a4a1 commit d7c351a
Show file tree
Hide file tree
Showing 5 changed files with 1,222 additions and 400 deletions.
309 changes: 309 additions & 0 deletions master/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package master
import (
"bufio"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -45,10 +46,13 @@ import (
"github.com/zhenghaoz/gorse/config"
"github.com/zhenghaoz/gorse/model/click"
"github.com/zhenghaoz/gorse/model/ranking"
"github.com/zhenghaoz/gorse/protocol"
"github.com/zhenghaoz/gorse/server"
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)

func (m *Master) CreateWebService() {
Expand Down Expand Up @@ -225,6 +229,8 @@ func (m *Master) StartHttpServer() {
container.Handle("/api/bulk/users", http.HandlerFunc(m.importExportUsers))
container.Handle("/api/bulk/items", http.HandlerFunc(m.importExportItems))
container.Handle("/api/bulk/feedback", http.HandlerFunc(m.importExportFeedback))
container.Handle("/api/dump", http.HandlerFunc(m.dump))
container.Handle("/api/restore", http.HandlerFunc(m.restore))
if m.workerScheduleHandler == nil {
container.Handle("/api/admin/schedule", http.HandlerFunc(m.scheduleAPIHandler))
} else {
Expand Down Expand Up @@ -1499,3 +1505,306 @@ func (s *Master) checkAdmin(request *http.Request) bool {
}
return false
}

const (
EOF = int64(0)
UserStream = int64(-1)
ItemStream = int64(-2)
FeedbackStream = int64(-3)
)

type DumpStats struct {
Users int
Items int
Feedback int
Duration time.Duration
}

func writeDump[T proto.Message](w io.Writer, data T) error {
bytes, err := proto.Marshal(data)
if err != nil {
return err
}
if err = binary.Write(w, binary.LittleEndian, int64(len(bytes))); err != nil {
return err
}
if _, err = w.Write(bytes); err != nil {
return err
}
return nil
}

func readDump[T proto.Message](r io.Reader, data T) (int64, error) {
var size int64
if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
return 0, err
}
if size <= 0 {
return size, nil
}
bytes := make([]byte, size)
if _, err := r.Read(bytes); err != nil {
return 0, err
}
return size, proto.Unmarshal(bytes, data)
}

func (m *Master) dump(response http.ResponseWriter, request *http.Request) {
if !m.checkAdmin(request) {
writeError(response, http.StatusUnauthorized, "unauthorized")
return
}
if request.Method != http.MethodGet {
writeError(response, http.StatusMethodNotAllowed, "method not allowed")
return
}
response.Header().Set("Content-Type", "application/octet-stream")
var stats DumpStats
start := time.Now()
// dump users
if err := binary.Write(response, binary.LittleEndian, UserStream); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
userStream, errChan := m.DataClient.GetUserStream(context.Background(), batchSize)
for users := range userStream {
for _, user := range users {
labels, err := json.Marshal(user.Labels)
if err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
if err := writeDump(response, &protocol.User{
UserId: user.UserId,
Labels: labels,
Comment: user.Comment,
}); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
stats.Users++
}
}
if err := <-errChan; err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
// dump items
if err := binary.Write(response, binary.LittleEndian, ItemStream); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
itemStream, errChan := m.DataClient.GetItemStream(context.Background(), batchSize, nil)
for items := range itemStream {
for _, item := range items {
labels, err := json.Marshal(item.Labels)
if err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
if err := writeDump(response, &protocol.Item{
ItemId: item.ItemId,
IsHidden: item.IsHidden,
Categories: item.Categories,
Timestamp: timestamppb.New(item.Timestamp),
Labels: labels,
Comment: item.Comment,
}); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
stats.Items++
}
}
if err := <-errChan; err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
// dump feedback
if err := binary.Write(response, binary.LittleEndian, FeedbackStream); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
feedbackStream, errChan := m.DataClient.GetFeedbackStream(context.Background(), batchSize, data.WithEndTime(*m.Config.Now()))
for feedbacks := range feedbackStream {
for _, feedback := range feedbacks {
if err := writeDump(response, &protocol.Feedback{
FeedbackType: feedback.FeedbackType,
UserId: feedback.UserId,
ItemId: feedback.ItemId,
Timestamp: timestamppb.New(feedback.Timestamp),
Comment: feedback.Comment,
}); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
stats.Feedback++
}
}
if err := <-errChan; err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
// dump EOF
if err := binary.Write(response, binary.LittleEndian, EOF); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
stats.Duration = time.Since(start)
log.Logger().Info("complete dump",
zap.Int("users", stats.Users),
zap.Int("items", stats.Items),
zap.Int("feedback", stats.Feedback),
zap.Duration("duration", stats.Duration))
server.Ok(restful.NewResponse(response), stats)
}

func (m *Master) restore(response http.ResponseWriter, request *http.Request) {
if !m.checkAdmin(request) {
writeError(response, http.StatusUnauthorized, "unauthorized")
return
}
if request.Method != http.MethodPost {
writeError(response, http.StatusMethodNotAllowed, "method not allowed")
return
}
var (
flag int64
err error
stats DumpStats
start = time.Now()
)
if err = binary.Read(request.Body, binary.LittleEndian, &flag); err != nil {
if errors.Is(err, io.EOF) {
server.Ok(restful.NewResponse(response), struct{}{})
return
} else {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
}
for flag != EOF {
switch flag {
case UserStream:
users := make([]data.User, 0, batchSize)
for {
var user protocol.User
if flag, err = readDump(request.Body, &user); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
if flag <= 0 {
break
}
labels := make(map[string]interface{})
if err := json.Unmarshal(user.Labels, &labels); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
users = append(users, data.User{
UserId: user.UserId,
Labels: labels,
Comment: user.Comment,
})
stats.Users++
if len(users) == batchSize {
if err := m.DataClient.BatchInsertUsers(context.Background(), users); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
users = users[:0]
}
}
if len(users) > 0 {
if err := m.DataClient.BatchInsertUsers(context.Background(), users); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
}
case ItemStream:
items := make([]data.Item, 0, batchSize)
for {
var item protocol.Item
if flag, err = readDump(request.Body, &item); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
if flag <= 0 {
break
}
labels := make(map[string]interface{})
if err := json.Unmarshal(item.Labels, &labels); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
items = append(items, data.Item{
ItemId: item.ItemId,
IsHidden: item.IsHidden,
Categories: item.Categories,
Timestamp: item.Timestamp.AsTime(),
Labels: labels,
Comment: item.Comment,
})
stats.Items++
if len(items) == batchSize {
if err := m.DataClient.BatchInsertItems(context.Background(), items); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
items = items[:0]
}
}
if len(items) > 0 {
if err := m.DataClient.BatchInsertItems(context.Background(), items); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
}
case FeedbackStream:
feedbacks := make([]data.Feedback, 0, batchSize)
for {
var feedback protocol.Feedback
if flag, err = readDump(request.Body, &feedback); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
if flag <= 0 {
break
}
feedbacks = append(feedbacks, data.Feedback{
FeedbackKey: data.FeedbackKey{
FeedbackType: feedback.FeedbackType,
UserId: feedback.UserId,
ItemId: feedback.ItemId,
},
Timestamp: feedback.Timestamp.AsTime(),
Comment: feedback.Comment,
})
stats.Feedback++
if len(feedbacks) == batchSize {
if err := m.DataClient.BatchInsertFeedback(context.Background(), feedbacks, true, true, true); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
feedbacks = feedbacks[:0]
}
}
if len(feedbacks) > 0 {
if err := m.DataClient.BatchInsertFeedback(context.Background(), feedbacks, true, true, true); err != nil {
writeError(response, http.StatusInternalServerError, err.Error())
return
}
}
default:
writeError(response, http.StatusInternalServerError, fmt.Sprintf("unknown flag %v", flag))
return
}
}
stats.Duration = time.Since(start)
log.Logger().Info("complete restore",
zap.Int("users", stats.Users),
zap.Int("items", stats.Items),
zap.Int("feedback", stats.Feedback),
zap.Duration("duration", stats.Duration))
server.Ok(restful.NewResponse(response), stats)
}
Loading

0 comments on commit d7c351a

Please sign in to comment.