Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update protobuf method for feedback #59

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions aggregator/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a

if strings.Contains(payload.Signature, ".") {
authenticated, err := auth.VerifyJwtKeyForUser(r.config.JwtSecret, payload.Signature, submitAddress)
if err != nil {
return nil, err
}

if !authenticated {
return nil, auth.ErrorUnAuthorized
if err != nil || !authenticated {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, auth.InvalidAPIKey)
}
} else {
// We need to have 3 things to verify the signature: the signature, the hash of the original data, and the public key of the signer. With this information we can determine if the private key holder of the public key pair did indeed sign the message
Expand All @@ -66,10 +62,10 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
sigPublicKey, err := crypto.SigToPub(hash, signature)
recoveredAddr := crypto.PubkeyToAddress(*sigPublicKey)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}
if submitAddress.String() != recoveredAddr.String() {
return nil, fmt.Errorf("Invalid signature")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}
}

Expand All @@ -83,7 +79,7 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
ss, err := token.SignedString(r.config.JwtSecret)

if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, InternalError)
}

return &avsproto.KeyResp{
Expand Down Expand Up @@ -114,7 +110,7 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

// hmacSampleSecret is a []byte containing your
Expand All @@ -123,16 +119,16 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
})

if err != nil {
return nil, err
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

if token.Header["alg"] != auth.JwtAlg {
return nil, fmt.Errorf("invalid signing algorithm")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

if claims, ok := token.Claims.(jwt.MapClaims); ok {
if claims["sub"] == "" {
return nil, fmt.Errorf("Missing subject")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

user := model.User{
Expand All @@ -155,7 +151,7 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {

return &user, nil
}
return nil, fmt.Errorf("Malform claims")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

// verifyOperator checks validity of the signature submit by operator related request
Expand Down
5 changes: 5 additions & 0 deletions aggregator/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package aggregator

const (
InternalError = "Internal Error"
)
20 changes: 10 additions & 10 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ type RpcServer struct {
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) CreateWallet(ctx context.Context, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
func (r *RpcServer) GetWallet(ctx context.Context, payload *avsproto.GetWalletReq) (*avsproto.GetWalletResp, error) {
user, err := r.verifyAuth(ctx)

if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}
r.config.Logger.Info("process create wallet",
"user", user.Address.String(),
Expand Down Expand Up @@ -76,7 +76,7 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest
func (r *RpcServer) ListWallets(ctx context.Context, payload *avsproto.ListWalletReq) (*avsproto.ListWalletResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list wallet",
Expand All @@ -88,14 +88,14 @@ func (r *RpcServer) ListWallets(ctx context.Context, payload *avsproto.ListWalle
}

return &avsproto.ListWalletResp{
Wallets: wallets,
Items: wallets,
}, nil
}

func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.IdReq) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process cancel task",
Expand All @@ -115,7 +115,7 @@ func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.IdReq) (*wr
func (r *RpcServer) DeleteTask(ctx context.Context, taskID *avsproto.IdReq) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process delete task",
Expand Down Expand Up @@ -151,7 +151,7 @@ func (r *RpcServer) CreateTask(ctx context.Context, taskPayload *avsproto.Create
func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksReq) (*avsproto.ListTasksResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list task",
Expand All @@ -164,7 +164,7 @@ func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksRe
func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListExecutionsReq) (*avsproto.ListExecutionsResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list execution",
Expand All @@ -177,7 +177,7 @@ func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListEx
func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsproto.Task, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process get task",
Expand All @@ -202,7 +202,7 @@ func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsp
func (r *RpcServer) TriggerTask(ctx context.Context, payload *avsproto.UserTriggerTaskReq) (*avsproto.UserTriggerTaskResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process trigger task",
Expand Down
4 changes: 3 additions & 1 deletion core/auth/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package auth

const (
InvalidAuthenticationKey = "Invalid authentication key"
AuthenticationError = "User authentication error"
InvalidSignatureFormat = "Invalid Signature Format"
InvalidAuthenticationKey = "User Auth key is invalid"
InvalidAPIKey = "API key is invalid"
)
18 changes: 9 additions & 9 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (n *Engine) GetSmartWallets(owner common.Address) ([]*avsproto.SmartWallet,
return wallets, nil
}

func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.GetWalletReq) (*avsproto.GetWalletResp, error) {
// Verify data
// when user passing a custom factory address, we want to validate it
if payload.FactoryAddress != "" && !common.IsHexAddress(payload.FactoryAddress) {
Expand Down Expand Up @@ -250,7 +250,7 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), StorageWriteError)
}

return &avsproto.CreateWalletResp{
return &avsproto.GetWalletResp{
Address: sender.Hex(),
Salt: salt.String(),
FactoryAddress: factoryAddress.Hex(),
Expand Down Expand Up @@ -414,7 +414,7 @@ func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksRe
}

taskResp := &avsproto.ListTasksResp{
Tasks: []*avsproto.ListTasksResp_Item{},
Items: []*avsproto.ListTasksResp_Item{},
Cursor: "",
}

Expand Down Expand Up @@ -445,7 +445,7 @@ func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksRe
task.Id = taskID

if t, err := task.ToProtoBuf(); err == nil {
taskResp.Tasks = append(taskResp.Tasks, &avsproto.ListTasksResp_Item{
taskResp.Items = append(taskResp.Items, &avsproto.ListTasksResp_Item{
Id: t.Id,
Owner: t.Owner,
SmartWalletAddress: t.SmartWalletAddress,
Expand All @@ -468,7 +468,7 @@ func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksRe
}

if total >= itemPerPage {
taskResp.Cursor = NewCursor(CursorDirectionNext, taskResp.Tasks[total-1].Id).String()
taskResp.Cursor = NewCursor(CursorDirectionNext, taskResp.Items[total-1].Id).String()
}

return taskResp, nil
Expand Down Expand Up @@ -573,8 +573,8 @@ func (n *Engine) ListExecutions(user *model.User, payload *avsproto.ListExecutio
}

executioResp := &avsproto.ListExecutionsResp{
Executions: []*avsproto.Execution{},
Cursor: "",
Items: []*avsproto.Execution{},
Cursor: "",
}

total := 0
Expand All @@ -592,7 +592,7 @@ func (n *Engine) ListExecutions(user *model.User, payload *avsproto.ListExecutio

exec := avsproto.Execution{}
if err := protojson.Unmarshal(kv.Value, &exec); err == nil {
executioResp.Executions = append(executioResp.Executions, &exec)
executioResp.Items = append(executioResp.Items, &exec)
total += 1
}
if total >= itemPerPage {
Expand All @@ -601,7 +601,7 @@ func (n *Engine) ListExecutions(user *model.User, payload *avsproto.ListExecutio
}

if total >= itemPerPage {
executioResp.Cursor = NewCursor(CursorDirectionNext, executioResp.Executions[total-1].Id).String()
executioResp.Cursor = NewCursor(CursorDirectionNext, executioResp.Items[total-1].Id).String()
}
return executioResp, nil
}
Expand Down
6 changes: 3 additions & 3 deletions examples/example.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ async function listTask(owner, token) {
},
metadata
);
console.log(`Found ${result.tasks.length} tasks created by`, process.argv[3]);
console.log(`Found ${result.items.length} tasks created by`, process.argv[3]);

for (const item of result.tasks) {
for (const item of result.items) {
console.log(util.inspect(item, { depth: 4, colors: true }));
}
console.log(util.inspect({cursor: result.cursor}, { depth: 4, colors: true }));
Expand Down Expand Up @@ -225,7 +225,7 @@ async function getWallets(owner, token) {
const tokenContract = new ethers.Contract(tokenAddress, tokenAbi, provider);

let wallets = [];
for (const wallet of walletsResp.wallets) {
for (const wallet of walletsResp.items) {
const balance = await provider.getBalance(wallet.address);
const balanceInEth = _.floor(ethers.formatEther(balance), 2);

Expand Down
Loading
Loading