Skip to content

Commit

Permalink
handle craned unexpect down for crun task
Browse files Browse the repository at this point in the history
Signed-off-by: L-Xiafeng <xiafeng.li@foxmail.com>
  • Loading branch information
L-Xiafeng committed Jun 21, 2024
1 parent b162796 commit 3abb8fb
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 25 deletions.
103 changes: 81 additions & 22 deletions internal/cfored/cfored.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ type GlobalVariables struct {
// Used by Cfored <--> Ctld state machine to de-multiplex messages from CraneCtld.
// Cfored <--> Ctld state machine GUARANTEES that NO `nil` will be sent into these channels.
// Used for calloc/crun with task id allocated.
//taskid -> pid_list
ctldReplyChannelMapByTaskIdProcId map[uint32] /*TaskId*/ map[uint32] /*ProcId*/ chan *protos.StreamCtldReply
//ctldReplyChannelMapByProcId map[uint32]chan *protos.StreamCtldReply

// Used by Calloc/Crun <--> Cfored state machine to multiplex messages
// these messages will be sent to CraneCtld
Expand Down Expand Up @@ -91,8 +89,9 @@ const (
CrunWaitIOForward StateOfCrunServer = 3
CrunWaitTaskComplete StateOfCrunServer = 4
CrunWaitTaskCancel StateOfCrunServer = 5
CrunForwardEnd StateOfCrunServer = 6
CancelTaskOfDeadCrun StateOfCrunServer = 7
CrunWaitForwardEnd StateOfCrunServer = 6
CrunForwardEnd StateOfCrunServer = 7
CancelTaskOfDeadCrun StateOfCrunServer = 8
)

func (cforedServer *GrpcCforedServer) CrunStream(toCrunStream protos.CraneForeD_CrunStreamServer) error {
Expand All @@ -103,6 +102,7 @@ func (cforedServer *GrpcCforedServer) CrunStream(toCrunStream protos.CraneForeD_
var reply *protos.StreamCforedCrunReply

var execCranedIds []string
cranedNum := atomic.Uint32{}
crunRequestChannel := make(chan grpcMessage[protos.StreamCrunRequest], 8)
go grpcStreamReceiver[protos.StreamCrunRequest](toCrunStream, crunRequestChannel)

Expand Down Expand Up @@ -275,7 +275,7 @@ CforedCrunStateMachineLoop:
if crunRequest != nil || err == nil {
log.Fatal("[Cfored<->Crun] Expect only nil (crun connection broken) here!")
}
log.Debug("[Cfored<->Crun] Connection to calloc was broken.")
log.Debug("[Cfored<->Crun] Connection to crun was broken.")

state = CancelTaskOfDeadCrun

Expand Down Expand Up @@ -315,10 +315,11 @@ CforedCrunStateMachineLoop:
case CrunWaitIOForward:
log.Debug("[Cfored<->Crun] Enter State WAIT_TASK_IO_FORWARD")

cranedNum.Store(uint32(len(execCranedIds)))
stopWaiting := atomic.Bool{}
stopWaiting.Store(false)
readyChannel := make(chan bool, 1)
go gCranedChanKeeper.waitCranedChannelsReady(execCranedIds, readyChannel, &stopWaiting)
go gCranedChanKeeper.waitCranedChannelsReady(execCranedIds, readyChannel, &stopWaiting, taskId, procId)

select {
case ctldReply := <-ctldReplyChannel:
Expand Down Expand Up @@ -382,8 +383,6 @@ CforedCrunStateMachineLoop:

case CrunWaitTaskComplete:
log.Debug("[Cfored<->Crun] Enter State Crun_Wait_Task_Complete")
cranedNum := atomic.Uint32{}
cranedNum.Store(uint32(len(execCranedIds)))
forwarding:
for {
select {
Expand Down Expand Up @@ -433,6 +432,13 @@ CforedCrunStateMachineLoop:
}

case taskMsg := <-TaskIoRequestChannel:
if taskMsg == nil {
log.Errorf("[Cfored<->Crun] Task %d Proc %d Craned down,cancel task", taskId, procId)
//connection err
cranedNum.Store(cranedNum.Load() - 1)
state = CrunWaitTaskCancel
break forwarding
}
if taskMsg.Type == protos.StreamCforedTaskIORequest_CRANED_PROC_OUTPUT {
if taskMsg.GetPayloadProcOutputReq().End {
num := cranedNum.Load() - 1
Expand All @@ -443,22 +449,20 @@ CforedCrunStateMachineLoop:
cranedNum.Store(num)
continue
}

} else {
reply = &protos.StreamCforedCrunReply{
Type: protos.StreamCforedCrunReply_TASK_IO_FORWARD,
Payload: &protos.StreamCforedCrunReply_PayloadTaskIoForwardReply{
PayloadTaskIoForwardReply: &protos.StreamCforedCrunReply_TaskIOForwardReply{
Msg: taskMsg.GetPayloadProcOutputReq().Msg,
},
}
reply = &protos.StreamCforedCrunReply{
Type: protos.StreamCforedCrunReply_TASK_IO_FORWARD,
Payload: &protos.StreamCforedCrunReply_PayloadTaskIoForwardReply{
PayloadTaskIoForwardReply: &protos.StreamCforedCrunReply_TaskIOForwardReply{
Msg: taskMsg.GetPayloadProcOutputReq().Msg,
},
}
log.Tracef("[Cfored<->Crun] fowarding msg %s to crun for taskid %d Proc %d", taskMsg.GetPayloadProcOutputReq().Msg, taskId, procId)
},
}
log.Tracef("[Cfored<->Crun] fowarding msg %s to crun for taskid %d Proc %d", taskMsg.GetPayloadProcOutputReq().Msg, taskId, procId)

if err := toCrunStream.Send(reply); err != nil {
log.Debugf("[Cfored<->Crun] Failed to send Request to crun: %s. "+
"The connection to calloc was broken.", err.Error())
"The connection to crun was broken.", err.Error())
state = CancelTaskOfDeadCrun
break forwarding
}
Expand Down Expand Up @@ -516,7 +520,62 @@ CforedCrunStateMachineLoop:
}
gVars.cforedRequestCtldChannel <- toCtldRequest

state = CrunForwardEnd
state = CrunWaitForwardEnd
}
}

case CrunWaitForwardEnd:
log.Debug("[Cfored<->Crun] Enter State Crun_Wait_Forward_End")
if cranedNum.Load() == 0 {
state = CrunForwardEnd
break
}
WaitForwardEnd:
for {
select {
case taskMsg := <-TaskIoRequestChannel:
if taskMsg == nil {
num := cranedNum.Load() - 1
if num == 0 {
state = CrunForwardEnd
break WaitForwardEnd
} else {
cranedNum.Store(num)
continue
}
}
if taskMsg.Type == protos.StreamCforedTaskIORequest_CRANED_PROC_OUTPUT {
if taskMsg.GetPayloadProcOutputReq().End {
num := cranedNum.Load() - 1
if num == 0 {
state = CrunForwardEnd
break WaitForwardEnd
} else {
cranedNum.Store(num)
continue
}
}

reply = &protos.StreamCforedCrunReply{
Type: protos.StreamCforedCrunReply_TASK_IO_FORWARD,
Payload: &protos.StreamCforedCrunReply_PayloadTaskIoForwardReply{
PayloadTaskIoForwardReply: &protos.StreamCforedCrunReply_TaskIOForwardReply{
Msg: taskMsg.GetPayloadProcOutputReq().Msg,
},
},
}
log.Tracef("[Cfored<->Crun] fowarding msg %s to crun for taskid %d Proc %d", taskMsg.GetPayloadProcOutputReq().Msg, taskId, procId)

if err := toCrunStream.Send(reply); err != nil {
log.Debugf("[Cfored<->Crun] Failed to send Request to crun: %s. "+
"The connection to crun was broken.", err.Error())
state = CancelTaskOfDeadCrun
break WaitForwardEnd
}
} else {
log.Fatal("[Cfored<->Crun] Expect Type CRANED_TASK_OUTPUTT")
break WaitForwardEnd
}
}
}

Expand All @@ -542,7 +601,7 @@ CforedCrunStateMachineLoop:
}
gVars.ctldReplyChannelMapMtx.Unlock()

gCranedChanKeeper.removeRemoteIoToCrunChannel(taskId, procId)
gCranedChanKeeper.crunEnd(taskId, procId, execCranedIds)
break CforedCrunStateMachineLoop

case CancelTaskOfDeadCrun:
Expand Down Expand Up @@ -571,7 +630,7 @@ CforedCrunStateMachineLoop:
delete(gVars.pidTaskIdMap, crunPid)
gVars.pidTaskIdMapMtx.Unlock()

gCranedChanKeeper.removeRemoteIoToCrunChannel(taskId, procId)
gCranedChanKeeper.crunEnd(taskId, procId, execCranedIds)
} else {
delete(gVars.ctldReplyChannelMapByPid, crunPid)
}
Expand Down
62 changes: 59 additions & 3 deletions internal/cfored/grpcServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ import (
"syscall"
)

type void struct{}

var voidMember void

type CranedChannelKeeper struct {
crunRequestChannelMtx sync.Mutex
crunRequestChannelCV *sync.Cond
Expand All @@ -42,6 +46,10 @@ type CranedChannelKeeper struct {
taskIORequestChannelMtx sync.Mutex
// I/O message from Craned to Crun
taskIORequestChannelMapByTaskIdProcId map[uint32] /*taskId*/ map[uint32] /*ProcId*/ chan *protos.StreamCforedTaskIORequest

//for error handle
taskIdProcIdMapMtx sync.Mutex
taskIdProcIdMapByCraned map[string] /*cranedId*/ map[uint32] /*taskId*/ map[uint32] /*procId*/ void
}

var gCranedChanKeeper *CranedChannelKeeper
Expand All @@ -51,6 +59,7 @@ func NewCranedChannelKeeper() *CranedChannelKeeper {
keeper.crunRequestChannelCV = sync.NewCond(&keeper.crunRequestChannelMtx)
keeper.crunRequestChannelMapByCranedId = make(map[string]chan *protos.StreamCrunRequest)
keeper.taskIORequestChannelMapByTaskIdProcId = make(map[uint32]map[uint32]chan *protos.StreamCforedTaskIORequest)
keeper.taskIdProcIdMapByCraned = make(map[string]map[uint32]map[uint32]void)
return keeper
}

Expand All @@ -68,7 +77,7 @@ func (keeper *CranedChannelKeeper) cranedDownAndRemoveChannelToCraned(cranedId s
keeper.crunRequestChannelMtx.Unlock()
}

func (keeper *CranedChannelKeeper) waitCranedChannelsReady(cranedIds []string, readyChan chan bool, stopWaiting *atomic.Bool) {
func (keeper *CranedChannelKeeper) waitCranedChannelsReady(cranedIds []string, readyChan chan bool, stopWaiting *atomic.Bool, taskId uint32, procId uint32) {
keeper.crunRequestChannelMtx.Lock()
defer keeper.crunRequestChannelMtx.Unlock()
for !stopWaiting.Load() {
Expand All @@ -86,12 +95,45 @@ func (keeper *CranedChannelKeeper) waitCranedChannelsReady(cranedIds []string, r
// Once Wait() returns, the lock is held again.
} else {
log.Debug("[Cfored<->Crun] All related craned up now")
keeper.taskIdProcIdMapMtx.Lock()
for _, cranedId := range cranedIds {
if _, exist := keeper.taskIdProcIdMapByCraned[cranedId]; !exist {
keeper.taskIdProcIdMapByCraned[cranedId] = make(map[uint32]map[uint32]void)
}
proc := make(map[uint32]void)
proc[procId] = voidMember
keeper.taskIdProcIdMapByCraned[cranedId][taskId] = proc
}
keeper.taskIdProcIdMapMtx.Unlock()
readyChan <- true
break
}
}
}

func (keeper *CranedChannelKeeper) UnexpectedCranedDown(cranedId string) {
keeper.taskIdProcIdMapMtx.Lock()
defer keeper.taskIdProcIdMapMtx.Unlock()
if _, exist := keeper.taskIdProcIdMapByCraned[cranedId]; !exist {
log.Errorf("Ignoring unexist craned %s unexpect down", cranedId)
} else {
keeper.taskIORequestChannelMtx.Lock()
for taskId, procIds := range keeper.taskIdProcIdMapByCraned[cranedId] {

for procId, _ := range procIds {
channel, exist := keeper.taskIORequestChannelMapByTaskIdProcId[taskId][procId]
if exist {
channel <- nil
} else {
log.Warningf("Trying forward to I/O to an unknown crun of task #%d proc.#%d", taskId, procId)
}
}

}
keeper.taskIORequestChannelMtx.Unlock()
}

}
func (keeper *CranedChannelKeeper) forwardCrunRequestToCranedChannels(request *protos.StreamCrunRequest, cranedIds []string) {
keeper.crunRequestChannelMtx.Lock()
for _, node := range cranedIds {
Expand Down Expand Up @@ -120,13 +162,26 @@ func (keeper *CranedChannelKeeper) forwardRemoteIoToCrun(taskId uint32, procId u
keeper.taskIORequestChannelMtx.Unlock()
}

func (keeper *CranedChannelKeeper) removeRemoteIoToCrunChannel(taskId uint32, procId uint32) {
func (keeper *CranedChannelKeeper) crunEnd(taskId uint32, procId uint32, cranedIds []string) {
keeper.taskIORequestChannelMtx.Lock()
delete(keeper.taskIORequestChannelMapByTaskIdProcId[taskId], procId)
if len(keeper.taskIORequestChannelMapByTaskIdProcId[taskId]) == 0 {
delete(keeper.taskIORequestChannelMapByTaskIdProcId, taskId)
}

keeper.taskIdProcIdMapMtx.Lock()
for _, cranedId := range cranedIds {
if _, exist := keeper.taskIdProcIdMapByCraned[cranedId]; !exist {
log.Errorf("CranedId %s should exist in CranedChannelKeeper", cranedId)
} else {
delete(keeper.taskIdProcIdMapByCraned[cranedId][taskId], procId)
if len(keeper.taskIdProcIdMapByCraned[cranedId][taskId]) == 0 {
delete(keeper.taskIdProcIdMapByCraned[cranedId], taskId)
}
}

}
keeper.taskIdProcIdMapMtx.Unlock()
keeper.taskIORequestChannelMtx.Unlock()
}

Expand Down Expand Up @@ -254,12 +309,13 @@ CforedCranedStateMachineLoop:
// Msg from craned
cranedReq, err := item.message, item.err
if err != nil { // Failure Edge
// Todo: do something when craned down
switch err {
case io.EOF:

fallthrough
default:
log.Errorf("[Cfored<->Craned] Craned %s unexpected down", cranedId)
gCranedChanKeeper.UnexpectedCranedDown(cranedId)
state = CranedUnReg
break cranedIOForwarding
}
Expand Down

0 comments on commit 3abb8fb

Please sign in to comment.