Skip to content

Commit

Permalink
Merge pull request #58 from AvaProtocol/continue-impl-for-other-node-…
Browse files Browse the repository at this point in the history
…execution

Continue impl for other node type execution, support  triggering
  • Loading branch information
v9n authored Dec 12, 2024
2 parents 97c45a4 + e8dc6bb commit b88c795
Show file tree
Hide file tree
Showing 28 changed files with 2,313 additions and 782 deletions.
24 changes: 10 additions & 14 deletions aggregator/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a

if strings.Contains(payload.Signature, ".") {
authenticated, err := auth.VerifyJwtKeyForUser(r.config.JwtSecret, payload.Signature, submitAddress)
if err != nil {
return nil, err
}

if !authenticated {
return nil, auth.ErrorUnAuthorized
if err != nil || !authenticated {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, auth.InvalidAPIKey)
}
} else {
// We need to have 3 things to verify the signature: the signature, the hash of the original data, and the public key of the signer. With this information we can determine if the private key holder of the public key pair did indeed sign the message
Expand All @@ -66,10 +62,10 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
sigPublicKey, err := crypto.SigToPub(hash, signature)
recoveredAddr := crypto.PubkeyToAddress(*sigPublicKey)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}
if submitAddress.String() != recoveredAddr.String() {
return nil, fmt.Errorf("Invalid signature")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}
}

Expand All @@ -83,7 +79,7 @@ func (r *RpcServer) GetKey(ctx context.Context, payload *avsproto.GetKeyReq) (*a
ss, err := token.SignedString(r.config.JwtSecret)

if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, InternalError)
}

return &avsproto.KeyResp{
Expand Down Expand Up @@ -114,7 +110,7 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

// hmacSampleSecret is a []byte containing your
Expand All @@ -123,16 +119,16 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
})

if err != nil {
return nil, err
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

if token.Header["alg"] != auth.JwtAlg {
return nil, fmt.Errorf("invalid signing algorithm")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

if claims, ok := token.Claims.(jwt.MapClaims); ok {
if claims["sub"] == "" {
return nil, fmt.Errorf("Missing subject")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

user := model.User{
Expand All @@ -155,7 +151,7 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {

return &user, nil
}
return nil, fmt.Errorf("Malform claims")
return nil, fmt.Errorf("%s", auth.InvalidAuthenticationKey)
}

// verifyOperator checks validity of the signature submit by operator related request
Expand Down
5 changes: 5 additions & 0 deletions aggregator/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package aggregator

const (
InternalError = "Internal Error"
)
44 changes: 32 additions & 12 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ type RpcServer struct {
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) CreateWallet(ctx context.Context, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
func (r *RpcServer) GetWallet(ctx context.Context, payload *avsproto.GetWalletReq) (*avsproto.GetWalletResp, error) {
user, err := r.verifyAuth(ctx)

if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}
r.config.Logger.Info("process create wallet",
"user", user.Address.String(),
Expand Down Expand Up @@ -76,7 +76,7 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest
func (r *RpcServer) ListWallets(ctx context.Context, payload *avsproto.ListWalletReq) (*avsproto.ListWalletResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list wallet",
Expand All @@ -88,19 +88,19 @@ func (r *RpcServer) ListWallets(ctx context.Context, payload *avsproto.ListWalle
}

return &avsproto.ListWalletResp{
Wallets: wallets,
Items: wallets,
}, nil
}

func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.IdReq) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process cancel task",
"user", user.Address.String(),
"taskID", taskID.Id,
"task_id", taskID.Id,
)

result, err := r.engine.CancelTaskByUser(user, string(taskID.Id))
Expand All @@ -115,12 +115,12 @@ func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.IdReq) (*wr
func (r *RpcServer) DeleteTask(ctx context.Context, taskID *avsproto.IdReq) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process delete task",
"user", user.Address.String(),
"taskID", string(taskID.Id),
"task_id", string(taskID.Id),
)

result, err := r.engine.DeleteTaskByUser(user, string(taskID.Id))
Expand Down Expand Up @@ -151,7 +151,7 @@ func (r *RpcServer) CreateTask(ctx context.Context, taskPayload *avsproto.Create
func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksReq) (*avsproto.ListTasksResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list task",
Expand All @@ -164,7 +164,7 @@ func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksRe
func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListExecutionsReq) (*avsproto.ListExecutionsResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process list execution",
Expand All @@ -177,12 +177,12 @@ func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListEx
func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsproto.Task, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.InvalidAuthenticationKey, err.Error())
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process get task",
"user", user.Address.String(),
"taskID", payload.Id,
"task_id", payload.Id,
)

if payload.Id == "" {
Expand All @@ -197,6 +197,26 @@ func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsp
return task.ToProtoBuf()
}

// TriggerTask emit a trigger event that cause the task to be queue and execute eventually. It's similar to a trigger
// sending by operator, but in this case the user manually provide a trigger point to force run it.
func (r *RpcServer) TriggerTask(ctx context.Context, payload *avsproto.UserTriggerTaskReq) (*avsproto.UserTriggerTaskResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process trigger task",
"user", user.Address.String(),
"task_id", payload.TaskId,
)

if payload.TaskId == "" {
return nil, status.Errorf(codes.InvalidArgument, taskengine.TaskIDMissing)
}

return r.engine.TriggerTask(user, payload)
}

// Operator action
func (r *RpcServer) SyncMessages(payload *avsproto.SyncMessagesReq, srv avsproto.Node_SyncMessagesServer) error {
err := r.engine.StreamCheckToOperator(payload, srv)
Expand Down
2 changes: 2 additions & 0 deletions aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/AvaProtocol/ap-avs/core/apqueue"
"github.com/AvaProtocol/ap-avs/core/taskengine"
"github.com/AvaProtocol/ap-avs/core/taskengine/macros"
)

func (agg *Aggregator) stopTaskEngine() {
Expand All @@ -26,6 +27,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
taskExecutor := taskengine.NewExecutor(agg.db, agg.logger)
taskengine.SetMacro(agg.config.Macros)
taskengine.SetCache(agg.cache)
macros.SetRpc(agg.config.SmartWallet.EthRpcUrl)

agg.worker.RegisterProcessor(
taskengine.ExecuteTask,
Expand Down
8 changes: 4 additions & 4 deletions core/apqueue/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (w *Worker) loop() {
for {
select {
case jid := <-w.q.eventCh:
w.logger.Info("process job from queue", "jobid", jid)
w.logger.Info("process job from queue", "job_id", jid)
job, err := w.q.Dequeue()
if err != nil {
w.logger.Error("failed to dequeue", "error", err)
Expand All @@ -53,15 +53,15 @@ func (w *Worker) loop() {
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "jobid", jid, "jobName", job.Name, "jobdata", string(job.Data))
w.logger.Info("decoded job", "job_id", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "jobid", jid, "task_id", job.Name)
w.logger.Info("succesfully perform job", "job_id", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Errorf("failed to perform job %w", err, "jobid", jid, "task_id", job.Name)
w.logger.Error("failed to perform job", "error", err, "job_id", jid, "task_id", job.Name)
}
case <-w.q.closeCh: // loop was stopped
return
Expand Down
4 changes: 3 additions & 1 deletion core/auth/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package auth

const (
InvalidAuthenticationKey = "Invalid authentication key"
AuthenticationError = "User authentication error"
InvalidSignatureFormat = "Invalid Signature Format"
InvalidAuthenticationKey = "User Auth key is invalid"
InvalidAPIKey = "API key is invalid"
)
Loading

0 comments on commit b88c795

Please sign in to comment.