From f224de5080c88a6000c8c592fd6941d085085f66 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Wed, 13 Nov 2024 00:00:40 +0800 Subject: [PATCH] replace import/export format with JSON (#885) --- go.mod | 6 +- go.sum | 4 +- master/rest.go | 524 +++++++++++++++++--------------------------- master/rest_test.go | 335 +++++++++++++--------------- 4 files changed, 358 insertions(+), 511 deletions(-) diff --git a/go.mod b/go.mod index ad9b3ee3d..c903d281f 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/zhenghaoz/gorse -go 1.22 +go 1.23.2 -toolchain go1.23.1 +toolchain go1.23.3 require ( github.com/ReneKroon/ttlcache/v2 v2.11.0 @@ -23,7 +23,7 @@ require ( github.com/golang/protobuf v1.5.2 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4 + github.com/gorse-io/dashboard v0.0.0-20241112140226-19a1b322242c github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 github.com/json-iterator/go v1.1.12 diff --git a/go.sum b/go.sum index ab7e3b29f..45a4eb7cd 100644 --- a/go.sum +++ b/go.sum @@ -301,8 +301,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4 h1:x0bLXsLkjEZdztd0Tw+Hx38vIjzabyj2Fk0EDitKcLk= -github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4/go.mod h1:bv2Yg9Pn4Dca4xPJbvibpF6LH6BjoxcjsEdIuojNano= +github.com/gorse-io/dashboard v0.0.0-20241112140226-19a1b322242c h1:OtOi5F+9Kou/ji0WwiJqVB82sB83279CpzfZcBdnJrU= +github.com/gorse-io/dashboard v0.0.0-20241112140226-19a1b322242c/go.mod h1:iWSDK04UCelym9Uy4YY/tDa6cMGTLpN49Najyhuv35A= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg= diff --git a/master/rest.go b/master/rest.go index 72c641fad..e5ea54de1 100644 --- a/master/rest.go +++ b/master/rest.go @@ -15,7 +15,6 @@ package master import ( - "bufio" "context" "encoding/binary" "encoding/json" @@ -968,24 +967,13 @@ func (m *Master) importExportUsers(response http.ResponseWriter, request *http.R switch request.Method { case http.MethodGet: var err error - response.Header().Set("Content-Type", "text/csv") - response.Header().Set("Content-Disposition", "attachment;filename=users.csv") - // write header - if _, err = response.Write([]byte("user_id,labels\r\n")); err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return - } - // write rows - userChan, errChan := m.DataClient.GetUserStream(ctx, batchSize) - for users := range userChan { + response.Header().Set("Content-Type", "application/jsonl") + response.Header().Set("Content-Disposition", "attachment;filename=users.jsonl") + encoder := json.NewEncoder(response) + userStream, errChan := m.DataClient.GetUserStream(ctx, batchSize) + for users := range userStream { for _, user := range users { - labels, err := json.Marshal(user.Labels) - if err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return - } - if _, err = response.Write([]byte(fmt.Sprintf("%s,%s\r\n", - base.Escape(user.UserId), base.Escape(string(labels))))); err != nil { + if err = encoder.Encode(user); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } @@ -996,89 +984,62 @@ func (m *Master) importExportUsers(response http.ResponseWriter, request *http.R return } case http.MethodPost: - hasHeader := formValue(request, "has-header", "true") == "true" - sep := formValue(request, "sep", ",") - // field separator must be a single character - if len(sep) != 1 { - server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character")) - return - } - labelSep := formValue(request, "label-sep", "|") - fmtString := formValue(request, "format", "ul") + // open file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() - m.importUsers(ctx, response, file, hasHeader, sep, labelSep, fmtString) - } -} - -func (m *Master) importUsers(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, labelSep, fmtString string) { - - lineCount := 0 - timeStart := time.Now() - users := make([]data.User, 0) - err := base.ReadLines(bufio.NewScanner(file), sep, func(lineNumber int, splits []string) bool { - var err error - // skip header - if hasHeader { - hasHeader = false - return true - } - splits, err = format(fmtString, "ul", splits, lineNumber) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return false - } - // 1. user id - if err = base.ValidateId(splits[0]); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid user id `%v` at line %d (%s)", splits[0], lineNumber, err.Error())) - return false - } - user := data.User{UserId: splits[0]} - // 2. labels - if splits[1] != "" { - var labels any - if err = json.Unmarshal([]byte(splits[1]), &labels); err != nil { + // parse and import users + decoder := json.NewDecoder(file) + lineCount := 0 + timeStart := time.Now() + users := make([]data.User, 0, batchSize) + for { + // parse line + var user data.User + if err = decoder.Decode(&user); err != nil { + if errors.Is(err, io.EOF) { + break + } + server.BadRequest(restful.NewResponse(response), err) + return + } + // validate user id + if err = base.ValidateId(user.UserId); err != nil { server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid labels `%v` at line %d (%s)", splits[1], lineNumber, err.Error())) - return false + fmt.Errorf("invalid user id `%v` at line %d (%s)", user.UserId, lineCount, err.Error())) + return + } + users = append(users, user) + // batch insert + if len(users) == batchSize { + err = m.DataClient.BatchInsertUsers(ctx, users) + if err != nil { + server.InternalServerError(restful.NewResponse(response), err) + return + } + users = make([]data.User, 0, batchSize) } - user.Labels = labels + lineCount++ } - users = append(users, user) - // batch insert - if len(users) == batchSize { + if len(users) > 0 { err = m.DataClient.BatchInsertUsers(ctx, users) if err != nil { server.InternalServerError(restful.NewResponse(response), err) - return false + return } - users = nil - } - lineCount++ - return true - }) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return - } - if len(users) > 0 { - err = m.DataClient.BatchInsertUsers(ctx, users) - if err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return } + m.notifyDataImported() + timeUsed := time.Since(timeStart) + log.Logger().Info("complete import users", + zap.Duration("time_used", timeUsed), + zap.Int("num_users", lineCount)) + server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) + default: + writeError(response, http.StatusMethodNotAllowed, "method not allowed") } - m.notifyDataImported() - timeUsed := time.Since(timeStart) - log.Logger().Info("complete import users", - zap.Duration("time_used", timeUsed), - zap.Int("num_users", lineCount)) - server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) } func (m *Master) importExportItems(response http.ResponseWriter, request *http.Request) { @@ -1098,25 +1059,13 @@ func (m *Master) importExportItems(response http.ResponseWriter, request *http.R switch request.Method { case http.MethodGet: var err error - response.Header().Set("Content-Type", "text/csv") - response.Header().Set("Content-Disposition", "attachment;filename=items.csv") - // write header - if _, err = response.Write([]byte("item_id,is_hidden,categories,time_stamp,labels,description\r\n")); err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return - } - // write rows - itemChan, errChan := m.DataClient.GetItemStream(ctx, batchSize, nil) - for items := range itemChan { + response.Header().Set("Content-Type", "application/jsonl") + response.Header().Set("Content-Disposition", "attachment;filename=items.jsonl") + encoder := json.NewEncoder(response) + itemStream, errChan := m.DataClient.GetItemStream(ctx, batchSize, nil) + for items := range itemStream { for _, item := range items { - labels, err := json.Marshal(item.Labels) - if err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return - } - if _, err = response.Write([]byte(fmt.Sprintf("%s,%t,%s,%v,%s,%s\r\n", - base.Escape(item.ItemId), item.IsHidden, base.Escape(strings.Join(item.Categories, "|")), - item.Timestamp, base.Escape(string(labels)), base.Escape(item.Comment)))); err != nil { + if err = encoder.Encode(item); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } @@ -1127,150 +1076,87 @@ func (m *Master) importExportItems(response http.ResponseWriter, request *http.R return } case http.MethodPost: - hasHeader := formValue(request, "has-header", "true") == "true" - sep := formValue(request, "sep", ",") - // field separator must be a single character - if len(sep) != 1 { - server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character")) - return - } - labelSep := formValue(request, "label-sep", "|") - fmtString := formValue(request, "format", "ihctld") + // open file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() - m.importItems(ctx, response, file, hasHeader, sep, labelSep, fmtString) - default: - writeError(response, http.StatusMethodNotAllowed, "method not allowed") - } -} - -func (m *Master) importItems(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, labelSep, fmtString string) { - lineCount := 0 - timeStart := time.Now() - items := make([]data.Item, 0) - err := base.ReadLines(bufio.NewScanner(file), sep, func(lineNumber int, splits []string) bool { - var err error - // skip header - if hasHeader { - hasHeader = false - return true - } - splits, err = format(fmtString, "ihctld", splits, lineNumber) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return false - } - // 1. item id - if err = base.ValidateId(splits[0]); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid item id `%v` at line %d (%s)", splits[0], lineNumber, err.Error())) - return false - } - item := data.Item{ItemId: splits[0]} - // 2. hidden - if splits[1] != "" { - item.IsHidden, err = strconv.ParseBool(splits[1]) - if err != nil { + // parse and import items + decoder := json.NewDecoder(file) + lineCount := 0 + timeStart := time.Now() + items := make([]data.Item, 0, batchSize) + for { + // parse line + var item server.Item + if err = decoder.Decode(&item); err != nil { + if errors.Is(err, io.EOF) { + break + } + server.BadRequest(restful.NewResponse(response), err) + return + } + // validate item id + if err = base.ValidateId(item.ItemId); err != nil { server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid hidden value `%v` at line %d (%s)", splits[1], lineNumber, err.Error())) - return false + fmt.Errorf("invalid item id `%v` at line %d (%s)", item.ItemId, lineCount, err.Error())) + return } - } - // 3. categories - if splits[2] != "" { - item.Categories = strings.Split(splits[2], labelSep) + // validate categories for _, category := range item.Categories { if err = base.ValidateId(category); err != nil { server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid category `%v` at line %d (%s)", category, lineNumber, err.Error())) - return false + fmt.Errorf("invalid category `%v` at line %d (%s)", category, lineCount, err.Error())) + return } } - } - // 4. timestamp - if splits[3] != "" { - item.Timestamp, err = dateparse.ParseAny(splits[3]) - if err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("failed to parse datetime `%v` at line %v", splits[1], lineNumber)) - return false + // parse timestamp + var timestamp time.Time + if item.Timestamp != "" { + timestamp, err = dateparse.ParseAny(item.Timestamp) + if err != nil { + server.BadRequest(restful.NewResponse(response), + fmt.Errorf("failed to parse datetime `%v` at line %v", item.Timestamp, lineCount)) + return + } } - } - // 5. labels - if splits[4] != "" { - var labels any - if err = json.Unmarshal([]byte(splits[4]), &labels); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("failed to parse labels `%v` at line %v", splits[4], lineNumber)) - return false + items = append(items, data.Item{ + ItemId: item.ItemId, + IsHidden: item.IsHidden, + Categories: item.Categories, + Timestamp: timestamp, + Labels: item.Labels, + Comment: item.Comment, + }) + // batch insert + if len(items) == batchSize { + err = m.DataClient.BatchInsertItems(ctx, items) + if err != nil { + server.InternalServerError(restful.NewResponse(response), err) + return + } + items = make([]data.Item, 0, batchSize) } - item.Labels = labels + lineCount++ } - // 6. comment - item.Comment = splits[5] - items = append(items, item) - // batch insert - if len(items) == batchSize { + if len(items) > 0 { err = m.DataClient.BatchInsertItems(ctx, items) if err != nil { server.InternalServerError(restful.NewResponse(response), err) - return false + return } - items = nil - } - lineCount++ - return true - }) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return - } - if len(items) > 0 { - err = m.DataClient.BatchInsertItems(ctx, items) - if err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return } + m.notifyDataImported() + timeUsed := time.Since(timeStart) + log.Logger().Info("complete import items", + zap.Duration("time_used", timeUsed), + zap.Int("num_items", lineCount)) + server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) + default: + writeError(response, http.StatusMethodNotAllowed, "method not allowed") } - m.notifyDataImported() - timeUsed := time.Since(timeStart) - log.Logger().Info("complete import items", - zap.Duration("time_used", timeUsed), - zap.Int("num_items", lineCount)) - server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) -} - -func format(inFmt, outFmt string, s []string, lineCount int) ([]string, error) { - if len(s) < len(inFmt) { - log.Logger().Error("number of fields mismatch", - zap.Int("expect", len(inFmt)), - zap.Int("actual", len(s))) - return nil, fmt.Errorf("number of fields mismatch at line %v", lineCount) - } - if inFmt == outFmt { - return s, nil - } - pool := make(map[uint8]string) - for i := range inFmt { - pool[inFmt[i]] = s[i] - } - out := make([]string, len(outFmt)) - for i, c := range outFmt { - out[i] = pool[uint8(c)] - } - return out, nil -} - -func formValue(request *http.Request, fieldName, defaultValue string) string { - value := request.FormValue(fieldName) - if value == "" { - return defaultValue - } - return value } func (m *Master) importExportFeedback(response http.ResponseWriter, request *http.Request) { @@ -1285,19 +1171,13 @@ func (m *Master) importExportFeedback(response http.ResponseWriter, request *htt switch request.Method { case http.MethodGet: var err error - response.Header().Set("Content-Type", "text/csv") - response.Header().Set("Content-Disposition", "attachment;filename=feedback.csv") - // write header - if _, err = response.Write([]byte("feedback_type,user_id,item_id,time_stamp\r\n")); err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return - } - // write rows - feedbackChan, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, data.WithEndTime(*m.Config.Now())) - for feedback := range feedbackChan { + response.Header().Set("Content-Type", "application/jsonl") + response.Header().Set("Content-Disposition", "attachment;filename=feedback.jsonl") + encoder := json.NewEncoder(response) + feedbackStream, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, data.WithEndTime(*m.Config.Now())) + for feedback := range feedbackStream { for _, v := range feedback { - if _, err = response.Write([]byte(fmt.Sprintf("%s,%s,%s,%v\r\n", - base.Escape(v.FeedbackType), base.Escape(v.UserId), base.Escape(v.ItemId), v.Timestamp))); err != nil { + if err = encoder.Encode(v); err != nil { server.InternalServerError(restful.NewResponse(response), err) return } @@ -1308,109 +1188,95 @@ func (m *Master) importExportFeedback(response http.ResponseWriter, request *htt return } case http.MethodPost: - hasHeader := formValue(request, "has-header", "true") == "true" - sep := formValue(request, "sep", ",") - // field separator must be a single character - if len(sep) != 1 { - server.BadRequest(restful.NewResponse(response), fmt.Errorf("field separator must be a single character")) - return - } - fmtString := formValue(request, "format", "fuit") - // import items + // open file file, _, err := request.FormFile("file") if err != nil { server.BadRequest(restful.NewResponse(response), err) return } defer file.Close() - m.importFeedback(ctx, response, file, hasHeader, sep, fmtString) - default: - writeError(response, http.StatusMethodNotAllowed, "method not allowed") - } -} - -func (m *Master) importFeedback(ctx context.Context, response http.ResponseWriter, file io.Reader, hasHeader bool, sep, fmtString string) { - var err error - scanner := bufio.NewScanner(file) - lineCount := 0 - timeStart := time.Now() - feedbacks := make([]data.Feedback, 0) - err = base.ReadLines(scanner, sep, func(lineNumber int, splits []string) bool { - if hasHeader { - hasHeader = false - return true - } - // reorder fields - splits, err = format(fmtString, "fuit", splits, lineNumber) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return false - } - feedback := data.Feedback{} - // 1. feedback type - feedback.FeedbackType = splits[0] - if err = base.ValidateId(splits[0]); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid feedback type `%v` at line %d (%s)", splits[0], lineNumber, err.Error())) - return false - } - // 2. user id - if err = base.ValidateId(splits[1]); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid user id `%v` at line %d (%s)", splits[1], lineNumber, err.Error())) - return false - } - feedback.UserId = splits[1] - // 3. item id - if err = base.ValidateId(splits[2]); err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("invalid item id `%v` at line %d (%s)", splits[2], lineNumber, err.Error())) - return false - } - feedback.ItemId = splits[2] - feedback.Timestamp, err = dateparse.ParseAny(splits[3]) - if err != nil { - server.BadRequest(restful.NewResponse(response), - fmt.Errorf("failed to parse datetime `%v` at line %d", splits[3], lineNumber)) - return false + // parse and import feedback + decoder := json.NewDecoder(file) + lineCount := 0 + timeStart := time.Now() + feedbacks := make([]data.Feedback, 0, batchSize) + for { + // parse line + var feedback server.Feedback + if err = decoder.Decode(&feedback); err != nil { + if errors.Is(err, io.EOF) { + break + } + server.BadRequest(restful.NewResponse(response), err) + return + } + // validate feedback type + if err = base.ValidateId(feedback.FeedbackType); err != nil { + server.BadRequest(restful.NewResponse(response), + fmt.Errorf("invalid feedback type `%v` at line %d (%s)", feedback.FeedbackType, lineCount, err.Error())) + return + } + // validate user id + if err = base.ValidateId(feedback.UserId); err != nil { + server.BadRequest(restful.NewResponse(response), + fmt.Errorf("invalid user id `%v` at line %d (%s)", feedback.UserId, lineCount, err.Error())) + return + } + // validate item id + if err = base.ValidateId(feedback.ItemId); err != nil { + server.BadRequest(restful.NewResponse(response), + fmt.Errorf("invalid item id `%v` at line %d (%s)", feedback.ItemId, lineCount, err.Error())) + return + } + // parse timestamp + var timestamp time.Time + if feedback.Timestamp != "" { + timestamp, err = dateparse.ParseAny(feedback.Timestamp) + if err != nil { + server.BadRequest(restful.NewResponse(response), + fmt.Errorf("failed to parse datetime `%v` at line %d", feedback.Timestamp, lineCount)) + return + } + } + feedbacks = append(feedbacks, data.Feedback{ + FeedbackKey: feedback.FeedbackKey, + Timestamp: timestamp, + Comment: feedback.Comment, + }) + // batch insert + if len(feedbacks) == batchSize { + // batch insert to data store + err = m.DataClient.BatchInsertFeedback(ctx, feedbacks, + m.Config.Server.AutoInsertUser, + m.Config.Server.AutoInsertItem, true) + if err != nil { + server.InternalServerError(restful.NewResponse(response), err) + return + } + feedbacks = make([]data.Feedback, 0, batchSize) + } + lineCount++ } - feedbacks = append(feedbacks, feedback) - // batch insert - if len(feedbacks) == batchSize { - // batch insert to data store + // insert to cache store + if len(feedbacks) > 0 { + // insert to data store err = m.DataClient.BatchInsertFeedback(ctx, feedbacks, m.Config.Server.AutoInsertUser, m.Config.Server.AutoInsertItem, true) if err != nil { server.InternalServerError(restful.NewResponse(response), err) - return false + return } - feedbacks = nil - } - lineCount++ - return true - }) - if err != nil { - server.BadRequest(restful.NewResponse(response), err) - return - } - // insert to cache store - if len(feedbacks) > 0 { - // insert to data store - err = m.DataClient.BatchInsertFeedback(ctx, feedbacks, - m.Config.Server.AutoInsertUser, - m.Config.Server.AutoInsertItem, true) - if err != nil { - server.InternalServerError(restful.NewResponse(response), err) - return } + m.notifyDataImported() + timeUsed := time.Since(timeStart) + log.Logger().Info("complete import feedback", + zap.Duration("time_used", timeUsed), + zap.Int("num_items", lineCount)) + server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) + default: + writeError(response, http.StatusMethodNotAllowed, "method not allowed") } - m.notifyDataImported() - timeUsed := time.Since(timeStart) - log.Logger().Info("complete import feedback", - zap.Duration("time_used", timeUsed), - zap.Int("num_items", lineCount)) - server.Ok(restful.NewResponse(response), server.Success{RowAffected: lineCount}) } var checkList = mapset.NewSet("delete_users", "delete_items", "delete_feedback", "delete_cache") @@ -1543,7 +1409,7 @@ func readDump[T proto.Message](r io.Reader, data T) (int64, error) { return size, nil } bytes := make([]byte, size) - if _, err := r.Read(bytes); err != nil { + if _, err := io.ReadFull(r, bytes); err != nil { return 0, err } return size, proto.Unmarshal(bytes, data) @@ -1696,7 +1562,7 @@ func (m *Master) restore(response http.ResponseWriter, request *http.Request) { if flag <= 0 { break } - labels := make(map[string]interface{}) + var labels any if err := json.Unmarshal(user.Labels, &labels); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return @@ -1732,7 +1598,7 @@ func (m *Master) restore(response http.ResponseWriter, request *http.Request) { if flag <= 0 { break } - labels := make(map[string]interface{}) + var labels any if err := json.Unmarshal(item.Labels, &labels); err != nil { writeError(response, http.StatusInternalServerError, err.Error()) return diff --git a/master/rest_test.go b/master/rest_test.go index 692d4f239..2f7321be1 100644 --- a/master/rest_test.go +++ b/master/rest_test.go @@ -101,6 +101,16 @@ func marshal(t *testing.T, v interface{}) string { return string(s) } +func marshalJSONLines[T any](t *testing.T, v []T) string { + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + for _, item := range v { + err := encoder.Encode(item) + assert.NoError(t, err) + } + return buf.String() +} + func convertToMapStructure(t *testing.T, v interface{}) map[string]interface{} { var m map[string]interface{} err := mapstructure.Decode(v, &m) @@ -126,12 +136,9 @@ func TestMaster_ExportUsers(t *testing.T) { w := httptest.NewRecorder() s.importExportUsers(w, req) assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=users.csv", w.Header().Get("Content-Disposition")) - assert.Equal(t, "user_id,labels\r\n"+ - "1,\"{\"\"gender\"\":\"\"male\"\",\"\"job\"\":\"\"engineer\"\"}\"\r\n"+ - "2,\"{\"\"gender\"\":\"\"male\"\",\"\"job\"\":\"\"lawyer\"\"}\"\r\n"+ - "3,\"{\"\"gender\"\":\"\"female\"\",\"\"job\"\":\"\"teacher\"\"}\"\r\n", w.Body.String()) + assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) + assert.Equal(t, "attachment;filename=users.jsonl", w.Header().Get("Content-Disposition")) + assert.Equal(t, marshalJSONLines(t, users), w.Body.String()) } func TestMaster_ExportItems(t *testing.T) { @@ -173,12 +180,9 @@ func TestMaster_ExportItems(t *testing.T) { w := httptest.NewRecorder() s.importExportItems(w, req) assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=items.csv", w.Header().Get("Content-Disposition")) - assert.Equal(t, "item_id,is_hidden,categories,time_stamp,labels,description\r\n"+ - "1,false,x,2020-01-01 01:01:01.000000001 +0000 UTC,\"{\"\"genre\"\":[\"\"comedy\"\",\"\"sci-fi\"\"]}\",\"o,n,e\"\r\n"+ - "2,false,x|y,2021-01-01 01:01:01.000000001 +0000 UTC,\"{\"\"genre\"\":[\"\"documentary\"\",\"\"sci-fi\"\"]}\",\"t\r\nw\r\no\"\r\n"+ - "3,true,,2022-01-01 01:01:01.000000001 +0000 UTC,null,\"\"\"three\"\"\"\r\n", w.Body.String()) + assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) + assert.Equal(t, "attachment;filename=items.jsonl", w.Header().Get("Content-Disposition")) + assert.Equal(t, marshalJSONLines(t, items), w.Body.String()) } func TestMaster_ExportFeedback(t *testing.T) { @@ -189,8 +193,8 @@ func TestMaster_ExportFeedback(t *testing.T) { // insert feedback feedbacks := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}}, - {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, } err := s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) assert.NoError(t, err) @@ -200,68 +204,23 @@ func TestMaster_ExportFeedback(t *testing.T) { w := httptest.NewRecorder() s.importExportFeedback(w, req) assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=feedback.csv", w.Header().Get("Content-Disposition")) - assert.Equal(t, "feedback_type,user_id,item_id,time_stamp\r\n"+ - "click,0,2,0001-01-01 00:00:00 +0000 UTC\r\n"+ - "read,2,6,0001-01-01 00:00:00 +0000 UTC\r\n"+ - "share,1,4,0001-01-01 00:00:00 +0000 UTC\r\n", w.Body.String()) + assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) + assert.Equal(t, "attachment;filename=feedback.jsonl", w.Header().Get("Content-Disposition")) + assert.Equal(t, marshalJSONLines(t, feedbacks), w.Body.String()) } func TestMaster_ImportUsers(t *testing.T) { s, cookie := newMockServer(t) defer s.Close(t) - - ctx := context.Background() - // send request - buf := bytes.NewBuffer(nil) - writer := multipart.NewWriter(buf) - err := writer.WriteField("has-header", "false") - assert.NoError(t, err) - err = writer.WriteField("sep", "\t") - assert.NoError(t, err) - err = writer.WriteField("label-sep", "::") - assert.NoError(t, err) - err = writer.WriteField("format", "lu") - assert.NoError(t, err) - file, err := writer.CreateFormFile("file", "users.csv") - assert.NoError(t, err) - _, err = file.Write([]byte("\"{\"\"gender\"\":\"\"male\"\",\"\"job\"\":\"\"engineer\"\"}\"\t1\n" + - "\"{\"\"gender\"\":\"\"male\"\",\"\"job\"\":\"\"lawyer\"\"}\"\t2\n" + - "\"{\"\"gender\"\":\"\"female\"\",\"\"job\"\":\"\"teacher\"\"}\"\t\"3\"\n")) - assert.NoError(t, err) - err = writer.Close() - assert.NoError(t, err) - req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) - req.Header.Set("Content-Type", writer.FormDataContentType()) - w := httptest.NewRecorder() - s.importExportUsers(w, req) - // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.DataClient.GetUsers(ctx, "", 100) - assert.NoError(t, err) - assert.Equal(t, []data.User{ - {UserId: "1", Labels: map[string]any{"gender": "male", "job": "engineer"}}, - {UserId: "2", Labels: map[string]any{"gender": "male", "job": "lawyer"}}, - {UserId: "3", Labels: map[string]any{"gender": "female", "job": "teacher"}}, - }, items) -} - -func TestMaster_ImportUsers_DefaultFormat(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) ctx := context.Background() // send request buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) - file, err := writer.CreateFormFile("file", "users.csv") + file, err := writer.CreateFormFile("file", "users.jsonl") assert.NoError(t, err) - _, err = file.Write([]byte("user_id,labels\r\n" + - "1,\"{\"\"性别\"\":\"\"男\"\",\"\"职业\"\":\"\"工程师\"\"}\"\r\n" + - "2,\"{\"\"性别\"\":\"\"男\"\",\"\"职业\"\":\"\"律师\"\"}\"\r\n" + - "\"3\",\"{\"\"性别\"\":\"\"女\"\",\"\"职业\"\":\"\"教师\"\"}\"\r\n")) + _, err = file.Write([]byte(`{"UserId":"1","Labels":{"性别":"男","职业":"工程师"}} +{"UserId":"2","Labels":{"性别":"男","职业":"律师"}} +{"UserId":"3","Labels":{"性别":"女","职业":"教师"}}`)) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) @@ -285,79 +244,15 @@ func TestMaster_ImportUsers_DefaultFormat(t *testing.T) { func TestMaster_ImportItems(t *testing.T) { s, cookie := newMockServer(t) defer s.Close(t) - ctx := context.Background() // send request buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) - err := writer.WriteField("has-header", "false") - assert.NoError(t, err) - err = writer.WriteField("sep", "\t") - assert.NoError(t, err) - err = writer.WriteField("label-sep", "::") - assert.NoError(t, err) - err = writer.WriteField("format", "ildtch") + file, err := writer.CreateFormFile("file", "items.jsonl") assert.NoError(t, err) - file, err := writer.CreateFormFile("file", "items.csv") - assert.NoError(t, err) - _, err = file.Write([]byte("1\t\"{\"\"genre\"\":[\"\"comedy\"\",\"\"sci-fi\"\"]}\"\t\"o,n,e\"\t2020-01-01 01:01:01.000000001 +0000 UTC\tx\t0\n" + - "2\t\"{\"\"genre\"\":[\"\"documentary\"\",\"\"sci-fi\"\"]}\"\t\"t\r\nw\r\no\"\t2021-01-01 01:01:01.000000001 +0000 UTC\tx::y\t0\n" + - "\"3\"\t\"\"\t\"\"\"three\"\"\"\t\"2022-01-01 01:01:01.000000001 +0000 UTC\"\t\t\"1\"\n")) - assert.NoError(t, err) - err = writer.Close() - assert.NoError(t, err) - req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) - req.Header.Set("Content-Type", writer.FormDataContentType()) - w := httptest.NewRecorder() - s.importExportItems(w, req) - // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.DataClient.GetItems(ctx, "", 100, nil) - assert.NoError(t, err) - assert.Equal(t, []data.Item{ - { - ItemId: "1", - IsHidden: false, - Categories: []string{"x"}, - Timestamp: time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC), - Labels: map[string]any{"genre": []any{"comedy", "sci-fi"}}, - Comment: "o,n,e", - }, - { - ItemId: "2", - IsHidden: false, - Categories: []string{"x", "y"}, - Timestamp: time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC), - Labels: map[string]any{"genre": []any{"documentary", "sci-fi"}}, - Comment: "t\r\nw\r\no", - }, - { - ItemId: "3", - IsHidden: true, - Categories: nil, - Timestamp: time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC), - Labels: nil, - Comment: "\"three\"", - }, - }, items) -} - -func TestMaster_ImportItems_DefaultFormat(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - - ctx := context.Background() - // send request - buf := bytes.NewBuffer(nil) - writer := multipart.NewWriter(buf) - file, err := writer.CreateFormFile("file", "items.csv") - assert.NoError(t, err) - _, err = file.Write([]byte("item_id,is_hidden,categories,time_stamp,labels,description\r\n" + - "1,false,x,2020-01-01 01:01:01.000000001 +0000 UTC,\"{\"\"类型\"\":[\"\"喜剧\"\",\"\"科幻\"\"]}\",one\r\n" + - "2,false,x|y,2021-01-01 01:01:01.000000001 +0000 UTC,\"{\"\"类型\"\":[\"\"卡通\"\",\"\"科幻\"\"]}\",two\r\n" + - "\"3\",\"true\",,\"2022-01-01 01:01:01.000000001 +0000 UTC\",,\"three\"\r\n")) + _, err = file.Write([]byte(`{"ItemId":"1","IsHidden":false,"Categories":["x"],"Timestamp":"2020-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["喜剧","科幻"]},"Comment":"one"} +{"ItemId":"2","IsHidden":false,"Categories":["x","y"],"Timestamp":"2021-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["卡通","科幻"]},"Comment":"two"} +{"ItemId":"3","IsHidden":true,"Timestamp":"2022-01-01 01:01:01.000000001 +0000 UTC","Comment":"three"}`)) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) @@ -401,55 +296,15 @@ func TestMaster_ImportItems_DefaultFormat(t *testing.T) { func TestMaster_ImportFeedback(t *testing.T) { s, cookie := newMockServer(t) defer s.Close(t) - - ctx := context.Background() - // send request - buf := bytes.NewBuffer(nil) - writer := multipart.NewWriter(buf) - err := writer.WriteField("format", "uift") - assert.NoError(t, err) - err = writer.WriteField("sep", "\t") - assert.NoError(t, err) - err = writer.WriteField("has-header", "false") - assert.NoError(t, err) - file, err := writer.CreateFormFile("file", "feedback.csv") - assert.NoError(t, err) - _, err = file.Write([]byte("0\t2\tclick\t0001-01-01 00:00:00 +0000 UTC\n" + - "2\t6\tread\t0001-01-01 00:00:00 +0000 UTC\n" + - "\"1\"\t\"4\"\t\"share\"\t\"0001-01-01 00:00:00 +0000 UTC\"\n")) - assert.NoError(t, err) - err = writer.Close() - assert.NoError(t, err) - req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) - req.Header.Set("Content-Type", writer.FormDataContentType()) - w := httptest.NewRecorder() - s.importExportFeedback(w, req) - // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, feedback, err := s.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - assert.Equal(t, []data.Feedback{ - {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}}, - {FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}}, - {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, - }, feedback) -} - -func TestMaster_ImportFeedback_Default(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) // send request ctx := context.Background() buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) - file, err := writer.CreateFormFile("file", "feedback.csv") + file, err := writer.CreateFormFile("file", "feedback.jsonl") assert.NoError(t, err) - _, err = file.Write([]byte("feedback_type,user_id,item_id,time_stamp\r\n" + - "click,0,2,0001-01-01 00:00:00 +0000 UTC\r\n" + - "read,2,6,0001-01-01 00:00:00 +0000 UTC\r\n" + - "\"share\",\"1\",\"4\",\"0001-01-01 00:00:00 +0000 UTC\"\r\n")) + _, err = file.Write([]byte(`{"FeedbackType":"click","UserId":"0","ItemId":"2","Timestamp":"0001-01-01 00:00:00 +0000 UTC"} +{"FeedbackType":"read","UserId":"2","ItemId":"6","Timestamp":"0001-01-01 00:00:00 +0000 UTC"} +{"FeedbackType":"share","UserId":"1","ItemId":"4","Timestamp":"0001-01-01 00:00:00 +0000 UTC"}`)) assert.NoError(t, err) err = writer.Close() assert.NoError(t, err) @@ -1064,3 +919,129 @@ func TestDumpAndRestore(t *testing.T) { assert.Equal(t, feedback, returnFeedback) } } + +func TestExportAndImport(t *testing.T) { + s, cookie := newMockServer(t) + defer s.Close(t) + ctx := context.Background() + // insert users + users := make([]data.User, batchSize+1) + for i := range users { + users[i] = data.User{ + UserId: fmt.Sprintf("%05d", i), + Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, + } + } + err := s.DataClient.BatchInsertUsers(ctx, users) + assert.NoError(t, err) + // insert items + items := make([]data.Item, batchSize+1) + for i := range items { + items[i] = data.Item{ + ItemId: fmt.Sprintf("%05d", i), + Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, + } + } + err = s.DataClient.BatchInsertItems(ctx, items) + assert.NoError(t, err) + // insert feedback + feedback := make([]data.Feedback, batchSize+1) + for i := range feedback { + feedback[i] = data.Feedback{ + FeedbackKey: data.FeedbackKey{ + FeedbackType: "click", + UserId: fmt.Sprintf("%05d", i), + ItemId: fmt.Sprintf("%05d", i), + }, + } + } + err = s.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) + assert.NoError(t, err) + + // export users + req := httptest.NewRequest("GET", "https://example.com/", nil) + req.Header.Set("Cookie", cookie) + w := httptest.NewRecorder() + s.importExportUsers(w, req) + assert.Equal(t, http.StatusOK, w.Code) + usersData := w.Body.Bytes() + // export items + req = httptest.NewRequest("GET", "https://example.com/", nil) + req.Header.Set("Cookie", cookie) + w = httptest.NewRecorder() + s.importExportItems(w, req) + assert.Equal(t, http.StatusOK, w.Code) + itemsData := w.Body.Bytes() + // export feedback + req = httptest.NewRequest("GET", "https://example.com/", nil) + req.Header.Set("Cookie", cookie) + w = httptest.NewRecorder() + s.importExportFeedback(w, req) + assert.Equal(t, http.StatusOK, w.Code) + feedbackData := w.Body.Bytes() + + err = s.DataClient.Purge() + assert.NoError(t, err) + // import users + buf := bytes.NewBuffer(nil) + writer := multipart.NewWriter(buf) + file, err := writer.CreateFormFile("file", "users.jsonl") + assert.NoError(t, err) + _, err = file.Write(usersData) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + req = httptest.NewRequest("POST", "https://example.com/", buf) + req.Header.Set("Cookie", cookie) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w = httptest.NewRecorder() + s.importExportUsers(w, req) + assert.Equal(t, http.StatusOK, w.Code) + // import items + buf = bytes.NewBuffer(nil) + writer = multipart.NewWriter(buf) + file, err = writer.CreateFormFile("file", "items.jsonl") + assert.NoError(t, err) + _, err = file.Write(itemsData) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + req = httptest.NewRequest("POST", "https://example.com/", buf) + req.Header.Set("Cookie", cookie) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w = httptest.NewRecorder() + s.importExportItems(w, req) + assert.Equal(t, http.StatusOK, w.Code) + // import feedback + buf = bytes.NewBuffer(nil) + writer = multipart.NewWriter(buf) + file, err = writer.CreateFormFile("file", "feedback.jsonl") + assert.NoError(t, err) + _, err = file.Write(feedbackData) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + req = httptest.NewRequest("POST", "https://example.com/", buf) + req.Header.Set("Cookie", cookie) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w = httptest.NewRecorder() + s.importExportFeedback(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // check data + _, returnUsers, err := s.DataClient.GetUsers(ctx, "", len(users)) + assert.NoError(t, err) + if assert.Equal(t, len(users), len(returnUsers)) { + assert.Equal(t, users, returnUsers) + } + _, returnItems, err := s.DataClient.GetItems(ctx, "", len(items), nil) + assert.NoError(t, err) + if assert.Equal(t, len(items), len(returnItems)) { + assert.Equal(t, items, returnItems) + } + _, returnFeedback, err := s.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) + assert.NoError(t, err) + if assert.Equal(t, len(feedback), len(returnFeedback)) { + assert.Equal(t, feedback, returnFeedback) + } +}