diff --git a/jprov/queue/queue.go b/jprov/queue/queue.go index 0851e45..0781344 100644 --- a/jprov/queue/queue.go +++ b/jprov/queue/queue.go @@ -1,6 +1,7 @@ package queue import ( + "errors" "fmt" "time" @@ -35,6 +36,58 @@ func (q *UploadQueue) Append(upload *types.Upload) { q.Queue = append(q.Queue, upload) } +// Return a list of messages of the upload queue up to maxMessageSize in FCFS order. +// Returns nil in these conditions: +// 1. maxMessageSize is too small +// 2. the UploadQueue is locked +// 3. the upload queue is empty +func (q *UploadQueue) PrepareMessage(maxMessageSize int) (messages []cosmosTypes.Msg) { + if maxMessageSize < 1 || !q.Locked || len(q.Queue) == 0 { + return nil + } + + var netMsgSize int + for _, q := range q.Queue { + msgSize := len(q.Message.String()) + + if netMsgSize+msgSize > maxMessageSize { + break + } else { + netMsgSize += msgSize + messages = append(messages, q.Message) + } + } + + return +} + +// Update the upload queue with the parameter fields of count +func (q *UploadQueue) UpdateQueue(count int, err error, res *cosmosTypes.TxResponse) { + if !q.Locked || len(q.Queue) == 0 || len(q.Queue) < count { + return + } + + for i := 0; i < count; i++ { + q := q.Queue[i] + + if err != nil { + q.Err = err + } else { + if res != nil { + if res.Code != 0 { + q.Err = errors.New(res.RawLog) + } else { + q.Response = res + } + } + } + + if q.Callback != nil { + q.Callback.Done() + } + } +} + func (q *UploadQueue) listenOnce(cmd *cobra.Command, providerName string) { if q.Locked { return @@ -46,9 +99,7 @@ func (q *UploadQueue) listenOnce(cmd *cobra.Command, providerName string) { ctx := utils.GetServerContextFromCmd(cmd) - l := len(q.Queue) - - if l == 0 { + if len(q.Queue) == 0 { return } @@ -57,56 +108,16 @@ func (q *UploadQueue) listenOnce(cmd *cobra.Command, providerName string) { ctx.Logger.Error(err.Error()) } - var totalSizeOfMsgs int - msgs := make([]cosmosTypes.Msg, 0) - uploads := make([]*types.Upload, 0) - - for i := 0; i < l; i++ { // loop through entire queue - - upload := q.Queue[i] - - uploadSize := len(upload.Message.String()) - - // if the size of the upload would put us past our cap, we cut off the queue and send only what fits - if totalSizeOfMsgs+uploadSize > maxSize { - msgs = msgs[:len(msgs)-1] - uploads = uploads[:len(uploads)-1] - l = i - - break - } else { - uploads = append(uploads, upload) - msgs = append(msgs, upload.Message) - totalSizeOfMsgs += len(upload.Message.String()) - } - - } + msgs := q.PrepareMessage(maxSize) clientCtx := client.GetClientContextFromCmd(cmd) ctx.Logger.Debug(fmt.Sprintf("total no. of msgs in proof transaction is: %d", len(msgs))) res, err := utils.SendTx(clientCtx, cmd.Flags(), fmt.Sprintf("Storage Provided by %s", providerName), msgs...) - for _, v := range uploads { - if v == nil { - continue - } - if err != nil { - v.Err = err - } else { - if res != nil { - if res.Code != 0 { - v.Err = fmt.Errorf(res.RawLog) - } else { - v.Response = res - } - } - } - if v.Callback != nil { - v.Callback.Done() - } - } - q.Queue = q.Queue[l:] // pop every upload that fit off the queue + q.UpdateQueue(len(msgs), err, res) + + q.Queue = q.Queue[len(msgs):] // pop every upload that fit off the queue } func (q *UploadQueue) StartListener(cmd *cobra.Command, providerName string) { diff --git a/jprov/queue/queue_test.go b/jprov/queue/queue_test.go index 363c12f..ee9f2db 100644 --- a/jprov/queue/queue_test.go +++ b/jprov/queue/queue_test.go @@ -24,6 +24,20 @@ func setupQueue(t *testing.T) (queue.UploadQueue, *require.Assertions) { return q, require } +func setupUpload(count int) (upload []*types.Upload) { + for i := 0; i < count; i++ { + msg := storagetypes.NewMsgInitProvider( + "test-address", + "localhost:3333", + "1000", + "test-key", + ) + upload = append(upload, &types.Upload{Message: msg}) + } + + return +} + func TestAppend(t *testing.T) { q, require := setupQueue(t) @@ -56,3 +70,50 @@ func TestAppend(t *testing.T) { require.Equal(stringQueue, string(data)) } + +func TestPrepareMessage(t *testing.T) { + cases := map[string]struct { + uq queue.UploadQueue + maxMsgSize int + resultSize int + }{ + "empty_queue": { + uq: queue.UploadQueue{ + Locked: true, + }, + maxMsgSize: 10, + resultSize: 0, + }, + "queue_exceed_max": { + uq: queue.UploadQueue{ + Locked: true, + Queue: setupUpload(10), + }, + maxMsgSize: 1, + resultSize: 0, + }, + "queue_msg_length": { + uq: queue.UploadQueue{ + Locked: true, + Queue: setupUpload(1), + }, + maxMsgSize: 500, + resultSize: len(setupUpload(1)[0].Message.String()), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + msgs := c.uq.PrepareMessage(c.maxMsgSize) + var msgSize int + for _, m := range msgs { + msgSize += len(m.String()) + } + + if c.resultSize != msgSize { + t.Log("Expected size: ", c.resultSize, " Result size: ", msgSize) + t.Fail() + } + }) + } +} diff --git a/jprov/server/attestation.go b/jprov/server/attestation.go index f233756..ace1689 100644 --- a/jprov/server/attestation.go +++ b/jprov/server/attestation.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strconv" @@ -18,6 +19,40 @@ import ( "github.com/spf13/cobra" ) +func verifyAttest(deal storageTypes.ActiveDeals, attest types.AttestRequest) (verified bool, err error) { + merkle := deal.Merkle + block := deal.Blocktoprove + blockNum, err := strconv.ParseInt(block, 10, 64) + if err != nil { + return false, err + } + + verified = storageKeeper.VerifyDeal(merkle, attest.HashList, blockNum, attest.Item) + + return +} + +func addMsgAttest(address string, cid string, q *queue.UploadQueue) (upload types.Upload, err error) { + msg := storageTypes.NewMsgAttest(address, cid) + + if err := msg.ValidateBasic(); err != nil { + return upload, err + } + + var wg sync.WaitGroup + wg.Add(1) + + upload = types.Upload{ + Message: msg, + Err: nil, + Callback: &wg, + Response: nil, + } + + q.Append(&upload) + return +} + func attest(w *http.ResponseWriter, r *http.Request, cmd *cobra.Command, q *queue.UploadQueue) { clientCtx, qerr := client.GetClientTxContext(cmd) if qerr != nil { @@ -45,22 +80,16 @@ func attest(w *http.ResponseWriter, r *http.Request, cmd *cobra.Command, q *queu deal, err := queryClient.ActiveDeals(context.Background(), dealReq) if err != nil { http.Error(*w, err.Error(), http.StatusBadRequest) - return } - merkle := deal.ActiveDeals.Merkle - block := deal.ActiveDeals.Blocktoprove - blockNum, err := strconv.ParseInt(block, 10, 64) + verified, err := verifyAttest(deal.ActiveDeals, attest) if err != nil { http.Error(*w, err.Error(), http.StatusBadRequest) return } - verified := storageKeeper.VerifyDeal(merkle, attest.HashList, blockNum, attest.Item) - if !verified { - http.Error(*w, err.Error(), http.StatusBadRequest) - return + http.Error(*w, errors.New("failed to verify attest").Error(), http.StatusBadRequest) } address, err := crypto.GetAddress(clientCtx) @@ -69,35 +98,21 @@ func attest(w *http.ResponseWriter, r *http.Request, cmd *cobra.Command, q *queu return } - msg := storageTypes.NewMsgAttest( // create new attest - address, - attest.Cid, - ) - if err := msg.ValidateBasic(); err != nil { + upload, err := addMsgAttest(address, attest.Cid, q) + if err != nil { http.Error(*w, err.Error(), http.StatusBadRequest) return } - var wg sync.WaitGroup - wg.Add(1) - - u := types.Upload{ - Message: msg, - Err: nil, - Callback: &wg, - Response: nil, - } - - q.Append(&u) - wg.Wait() + upload.Callback.Wait() - if u.Err != nil { - http.Error(*w, u.Err.Error(), http.StatusBadRequest) + if upload.Err != nil { + http.Error(*w, upload.Err.Error(), http.StatusBadRequest) return } - if u.Response.Code != 0 { - http.Error(*w, fmt.Errorf(u.Response.RawLog).Error(), http.StatusBadRequest) + if upload.Response.Code != 0 { + http.Error(*w, fmt.Errorf(upload.Response.RawLog).Error(), http.StatusBadRequest) return } diff --git a/jprov/server/attestation_test.go b/jprov/server/attestation_test.go new file mode 100644 index 0000000..5a88cbe --- /dev/null +++ b/jprov/server/attestation_test.go @@ -0,0 +1,74 @@ +package server_test + +import ( + "testing" + + "github.com/JackalLabs/jackal-provider/jprov/queue" + "github.com/JackalLabs/jackal-provider/jprov/server" + "github.com/JackalLabs/jackal-provider/jprov/testutils" + "github.com/JackalLabs/jackal-provider/jprov/types" + "github.com/stretchr/testify/require" +) + +func TestVerifyAttest(t *testing.T) { + cases := map[string]struct { + attest types.AttestRequest + verified bool + expErr bool + }{ + "wrong proof": { + attest: types.AttestRequest{ + Cid: "-", + Item: "0", + }, + verified: false, + expErr: false, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + testFile := testutils.NewFile([]byte("hello world")) + + c.attest.HashList = string(testFile.GetJsonProof()) + + v, e := server.VerifyAttest(testFile.GenerateActiveDeal(), c.attest) + + if c.verified != v { + t.Log("expected: ", c.verified, " got: ", v) + t.Fail() + } + if !c.expErr && e != nil { + t.Log("expect no error, got: ", e) + t.Fail() + } + }) + } +} + +func TestAddMsgAttest(t *testing.T) { + cases := map[string]struct { + address string + cid string + expErr bool + }{ + "invalid_address": { + address: "invalid_address", + cid: "jklc1dmcul9svpv0z2uzfv30lz0kcjrpdfmmfccskt06wpy8vfqrhp4nsgvgz32", + expErr: true, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + q := queue.UploadQueue{} + _, err := server.AddAttestMsg(c.address, c.cid, &q) + + if c.expErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/jprov/server/file_server.go b/jprov/server/file_server.go index 257d586..7f268c5 100644 --- a/jprov/server/file_server.go +++ b/jprov/server/file_server.go @@ -50,19 +50,7 @@ func saveFile(file multipart.File, handler *multipart.FileHeader, sender string, return err } - cidHash := sha256.New() - - var str strings.Builder // building the FID - str.WriteString(sender) - str.WriteString(address) - str.WriteString(fid) - - _, err = io.WriteString(cidHash, str.String()) - if err != nil { - return err - } - cid := cidHash.Sum(nil) - strCid, err := utils.MakeCid(cid) + cid, err := buildCid(address, sender, fid) if err != nil { return err } @@ -77,28 +65,16 @@ func saveFile(file multipart.File, handler *multipart.FileHeader, sender string, } wg.Wait() - v := types.UploadResponse{ - CID: strCid, - FID: fid, - } - if msg.Err != nil { ctx.Logger.Error(msg.Err.Error()) - v := types.ErrorResponse{ - Error: msg.Err.Error(), - } - (*w).WriteHeader(http.StatusInternalServerError) - err = json.NewEncoder(*w).Encode(v) - } else { - err = json.NewEncoder(*w).Encode(v) } - if err != nil { + if err = writeResponse(*w, *msg, fid, cid); err != nil { ctx.Logger.Error("Json Encode Error: %v", err) return err } - err = utils.SaveToDatabase(fid, strCid, db, ctx.Logger) + err = utils.SaveToDatabase(fid, cid, db, ctx.Logger) if err != nil { return err } @@ -106,6 +82,38 @@ func saveFile(file multipart.File, handler *multipart.FileHeader, sender string, return nil } +func writeResponse(w http.ResponseWriter, upload types.Upload, fid, cid string) error { + if upload.Err != nil { + resp := types.ErrorResponse{ + Error: upload.Err.Error(), + } + return json.NewEncoder(w).Encode(resp) + } + + resp := types.UploadResponse{ + CID: cid, + FID: fid, + } + + return json.NewEncoder(w).Encode(resp) +} + +func buildCid(address, sender, fid string) (string, error) { + h := sha256.New() + + var footprint strings.Builder // building FID + footprint.WriteString(sender) + footprint.WriteString(address) + footprint.WriteString(fid) + + _, err := io.WriteString(h, footprint.String()) + if err != nil { + return "", err + } + + return utils.MakeCid(h.Sum(nil)) +} + func MakeContract(cmd *cobra.Command, fid string, sender string, wg *sync.WaitGroup, q *queue.UploadQueue, merkleroot string, filesize string) (*types.Upload, error) { ctx := utils.GetServerContextFromCmd(cmd) clientCtx, err := client.GetClientTxContext(cmd) diff --git a/jprov/server/file_server_test.go b/jprov/server/file_server_test.go new file mode 100644 index 0000000..d5e67c3 --- /dev/null +++ b/jprov/server/file_server_test.go @@ -0,0 +1,99 @@ +package server_test + +import ( + "crypto/sha256" + "encoding/json" + "errors" + "net/http/httptest" + "testing" + + "github.com/JackalLabs/jackal-provider/jprov/server" + "github.com/JackalLabs/jackal-provider/jprov/types" + "github.com/JackalLabs/jackal-provider/jprov/utils" + "github.com/stretchr/testify/assert" +) + +func TestWriteResponse(t *testing.T) { + cases := map[string]struct { + fid string + cid string + hasMsgErr bool + expErr bool + }{ + "no_error_response": { + fid: "1", + cid: "1", + hasMsgErr: false, + expErr: false, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + rec := httptest.NewRecorder() + + upload := types.Upload{} + + if c.hasMsgErr { + upload.Err = errors.New("example error") + } + + err := server.WriteResponse(rec, upload, c.fid, c.cid) + assert.NoError(t, err) + + resp := types.UploadResponse{ + CID: c.cid, + FID: c.fid, + } + + expResult, err := json.Marshal(resp) + if err != nil { + t.Error(err) + } + + assert.NotNil(t, rec.Body) + // converted to string for easier reading + assert.Equal(t, string(expResult)+"\n", rec.Body.String()) + }) + } +} + +func TestBuildCid(t *testing.T) { + cases := map[string]struct { + address string + sender string + fid string + expErr bool + }{ + "valid_cid": { + address: "example_address", + sender: "example_sender", + fid: "example_fid", + expErr: false, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + cid, err := server.BuildCid(c.address, c.sender, c.fid) + + if c.expErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + footprint := c.sender + c.address + c.fid + + h := sha256.New() + _, err = h.Write([]byte(footprint)) + if err != nil { + t.Error(err) + t.FailNow() + } + expCid, _ := utils.MakeCid(h.Sum(nil)) + + assert.Equal(t, expCid, cid) + }) + } +} diff --git a/jprov/server/proofs.go b/jprov/server/proofs.go index f7030a3..5ec1ec8 100644 --- a/jprov/server/proofs.go +++ b/jprov/server/proofs.go @@ -35,52 +35,58 @@ import ( "github.com/spf13/cobra" ) -func CreateMerkleForProof(clientCtx client.Context, filename string, index int, ctx *utils.Context) (string, string, error) { - files := utils.GetStoragePathForPiece(clientCtx, filename, index) - - item, err := os.ReadFile(files) // read only the chunk we need +func GetMerkleTree(ctx client.Context, filename string) (*merkletree.MerkleTree, error) { + rawTree, err := os.ReadFile(utils.GetStoragePathForTree(ctx, filename)) if err != nil { - ctx.Logger.Error("Error can't open file!") - return "", "", err + return &merkletree.MerkleTree{}, fmt.Errorf("unable to find merkle tree for: %s", filename) } - rawTree, err := os.ReadFile(utils.GetStoragePathForTree(clientCtx, filename)) + return merkletree.ImportMerkleTree(rawTree, sha3.New512()) +} + +func GenerateMerkleProof(tree merkletree.MerkleTree, index int, item []byte) (valid bool, proof *merkletree.Proof, err error) { + h := sha256.New() + _, err = io.WriteString(h, fmt.Sprintf("%d%x", index, item)) if err != nil { - ctx.Logger.Error("Error can't find tree!") - return "", "", err + return } - tree, err := merkletree.ImportMerkleTree(rawTree, sha3.New512()) // import the tree instead of creating the tree on the fly + proof, err = tree.GenerateProof(h.Sum(nil), 0) if err != nil { - ctx.Logger.Error("Error can't import tree!") - return "", "", err + return } - h := sha256.New() - _, err = io.WriteString(h, fmt.Sprintf("%d%x", index, item)) + valid, err = merkletree.VerifyProofUsing(h.Sum(nil), false, proof, [][]byte{tree.Root()}, sha3.New512()) + return +} + +func CreateMerkleForProof(clientCtx client.Context, filename string, index int, ctx *utils.Context) (string, string, error) { + files := utils.GetStoragePathForPiece(clientCtx, filename, index) + + item, err := os.ReadFile(files) // read only the chunk we need if err != nil { + ctx.Logger.Error("Error can't open file!") return "", "", err } - ditem := h.Sum(nil) - proof, err := tree.GenerateProof(ditem, 0) + mTree, err := GetMerkleTree(clientCtx, filename) if err != nil { return "", "", err } - jproof, err := json.Marshal(*proof) + verified, proof, err := GenerateMerkleProof(*mTree, index, item) if err != nil { + ctx.Logger.Error(err.Error()) return "", "", err } - verified, err := merkletree.VerifyProofUsing(ditem, false, proof, [][]byte{tree.Root()}, sha3.New512()) + jproof, err := json.Marshal(*proof) if err != nil { - ctx.Logger.Error(err.Error()) return "", "", err } if !verified { - ctx.Logger.Info("Cannot verify") + ctx.Logger.Info("unable to generate valid proof") } return fmt.Sprintf("%x", item), string(jproof), nil @@ -217,7 +223,7 @@ func requestAttestation(clientCtx client.Context, cid string, hashList string, i pwg.Wait() - if count < 3 { + if count < 3 { // NOTE: this value can change in chain params fmt.Println("failed to get enough attestations...") return fmt.Errorf("failed to get attestations") } diff --git a/jprov/server/proofs_test.go b/jprov/server/proofs_test.go new file mode 100644 index 0000000..dc2e069 --- /dev/null +++ b/jprov/server/proofs_test.go @@ -0,0 +1,53 @@ +package server_test + +import ( + "crypto/sha256" + "fmt" + "io" + "testing" + + "github.com/JackalLabs/jackal-provider/jprov/server" + // "github.com/JackalLabs/jackal-provider/jprov/types" + // "github.com/JackalLabs/jackal-provider/jprov/testutils" + "github.com/stretchr/testify/require" + merkletree "github.com/wealdtech/go-merkletree" + "github.com/wealdtech/go-merkletree/sha3" +) + +func TestGenerateMerkleProof(t *testing.T) { + cases := map[string]struct { + index int + item []byte + expValid bool + expErr bool + }{ + "valid proof": { + index: 0, + item: []byte("hello"), + expValid: true, + expErr: false, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + data := [][]byte{[]byte("hello"), []byte("world")} + for i, item := range data { + h := sha256.New() + _, err := io.WriteString(h, fmt.Sprintf("%d%x", i, item)) + require.NoError(t, err) + data[i] = h.Sum(nil) + } + tree, err := merkletree.NewUsing(data, sha3.New512(), false) + require.NoError(t, err) + + valid, _, err := server.GenerateMerkleProof(*tree, c.index, c.item) + require.EqualValues(t, c.expValid, valid) + if c.expErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/jprov/server/server_export_test.go b/jprov/server/server_export_test.go new file mode 100644 index 0000000..edc750c --- /dev/null +++ b/jprov/server/server_export_test.go @@ -0,0 +1,8 @@ +package server + +var ( + VerifyAttest = verifyAttest + AddAttestMsg = addMsgAttest + BuildCid = buildCid + WriteResponse = writeResponse +) diff --git a/jprov/testutils/testFile.go b/jprov/testutils/testFile.go new file mode 100644 index 0000000..b676527 --- /dev/null +++ b/jprov/testutils/testFile.go @@ -0,0 +1,57 @@ +package testutils + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + + storagetypes "github.com/jackalLabs/canine-chain/x/storage/types" + + "github.com/wealdtech/go-merkletree" + "github.com/wealdtech/go-merkletree/sha3" +) + +// File with a single block +type MerkleFile struct { + data []byte + tree merkletree.MerkleTree +} + +func NewFile(data []byte) MerkleFile { + h := sha256.New() + _, err := io.WriteString(h, fmt.Sprintf("%d%x", 0, data)) + if err != nil { + panic(err) + } + + raw := [][]byte{h.Sum(nil)} + + tree, err := merkletree.NewUsing(raw, sha3.New512(), false) + if err != nil { + panic(err) + } + + return MerkleFile{data: h.Sum(nil), tree: *tree} +} + +func (m *MerkleFile) GetProof() merkletree.Proof { + proof, err := m.tree.GenerateProof(m.data, 0) + if err != nil { + panic(err) + } + return *proof +} + +func (m *MerkleFile) GetJsonProof() []byte { + proof, err := json.Marshal(m.GetProof()) + if err != nil { + panic(err) + } + return proof +} + +func (m *MerkleFile) GenerateActiveDeal() storagetypes.ActiveDeals { + return storagetypes.ActiveDeals{Blocktoprove: "0", Merkle: hex.EncodeToString(m.tree.Root())} +}