diff --git a/session/osc.go b/session/osc.go index 9d2c0577..d869e85d 100644 --- a/session/osc.go +++ b/session/osc.go @@ -200,7 +200,27 @@ func (s *session) mysqlExecuteAlterTableOsc(r *Record) { buf.WriteString("'") str := buf.String() - s.execCommand(r, "", "sh", []string{"-c", str}) + _ = s.execCommand(r, "", "sh", []string{"-c", str}) +} + +// getSocketFile return gh-ost socket file +// unix socket file max 104 characters (or 107) +func (s *session) getSocketFile(r *Record) string { + socketFile := fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", s.opt.Host, s.opt.Port, + r.TableInfo.Schema, r.TableInfo.Name) + if len(socketFile) > 100 { + // 字符串过长时转换为hash值 + host := truncateString(s.opt.Host, 30) + dbName := truncateString(r.TableInfo.Schema, 30) + tableName := truncateString(r.TableInfo.Name, 30) + socketFile = fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", host, s.opt.Port, + dbName, tableName) + if len(socketFile) > 100 { + socketFile = fmt.Sprintf("/tmp/gh%s%d%s%s.sock", host, s.opt.Port, + dbName, tableName) + } + } + return socketFile } func (s *session) mysqlExecuteWithGhost(r *Record) { @@ -229,7 +249,7 @@ func (s *session) mysqlExecuteWithGhost(r *Record) { buf.WriteString("\" ") if s.osc.OscPrintSql { - buf.WriteString(" --print ") + buf.WriteString(" --verbose ") } // RDS数据库需要做特殊处理 @@ -284,23 +304,10 @@ func (s *session) mysqlExecuteWithGhost(r *Record) { buf.WriteString(fmt.Sprintf(" --postpone-cut-over-flag-file=%s", s.ghost.GhostPostponeCutOverFlagFile)) buf.WriteString(fmt.Sprintf(" --initially-drop-socket-file=%t", s.ghost.GhostInitiallyDropSocketFile)) - socketFile := fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", s.opt.Host, s.opt.Port, - r.TableInfo.Schema, r.TableInfo.Name) - if len(socketFile) > 100 { - // 字符串过长时转换为hash值 - host := truncateString(s.opt.Host, 30) - dbName := truncateString(r.TableInfo.Schema, 30) - tableName := truncateString(r.TableInfo.Name, 30) - socketFile = fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", host, s.opt.Port, - dbName, tableName) - if len(socketFile) > 100 { - socketFile = fmt.Sprintf("/tmp/gh%s%d%s%s.sock", host, s.opt.Port, - dbName, tableName) - } - } - + // unix socket file max 104 characters (or 107) + socketFile := s.getSocketFile(r) if _, err := os.Stat(socketFile); err == nil { - s.appendErrorMessage("listen unix socket file already in use") + s.appendErrorMessage("listen unix socket file already in use, need to clean up manually") return } else if err != nil && !strings.Contains(err.Error(), "no such file or directory") { log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) @@ -358,7 +365,12 @@ func (s *session) mysqlExecuteWithGhost(r *Record) { str := buf.String() // log.Info(str) - s.execCommand(r, socketFile, "sh", []string{"-c", str}) + if err := s.execCommand(r, socketFile, "sh", []string{"-c", str}); err != nil { + // 当失败时自动清理socket文件 + if !strings.Contains(err.Error(), "file already in use") { + os.Remove(socketFile) + } + } } func (s *session) mysqlExecuteAlterTableGhost(r *Record) { @@ -621,22 +633,14 @@ func (s *session) mysqlExecuteAlterTableGhost(r *Record) { } if migrationContext.ServeSocketFile == "" { // unix socket file max 104 characters (or 107) - socketFile := fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", s.opt.Host, s.opt.Port, - migrationContext.DatabaseName, migrationContext.OriginalTableName) - if len(socketFile) > 100 { - // 字符串过长时转换为hash值 - host := truncateString(s.opt.Host, 30) - dbName := truncateString(migrationContext.DatabaseName, 30) - tableName := truncateString(migrationContext.OriginalTableName, 30) - - socketFile = fmt.Sprintf("/tmp/gh-ost.%s.%d.%s.%s.sock", host, s.opt.Port, - dbName, tableName) - - if len(socketFile) > 100 { - socketFile = fmt.Sprintf("/tmp/gh%s%d%s%s.sock", host, s.opt.Port, - dbName, tableName) - } - + socketFile := s.getSocketFile(r) + if _, err := os.Stat(socketFile); err == nil { + s.appendErrorMessage("listen unix socket file already in use, need to clean up manually") + return + } else if err != nil && !strings.Contains(err.Error(), "no such file or directory") { + log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) + s.appendErrorMessage(err.Error()) + return } migrationContext.ServeSocketFile = socketFile } @@ -780,7 +784,7 @@ func (s *session) mysqlExecuteAlterTableGhost(r *Record) { } } -func (s *session) execCommand(r *Record, socketFile string, commandName string, params []string) bool { +func (s *session) execCommand(r *Record, socketFile string, commandName string, params []string) error { //函数返回一个*Cmd,用于使用给出的参数执行name指定的程序 cmd := exec.Command(commandName, params...) @@ -791,13 +795,13 @@ func (s *session) execCommand(r *Record, socketFile string, commandName string, if err != nil { s.appendErrorMessage(err.Error()) log.Error(err) - return false + return err } stderr, err := cmd.StderrPipe() if err != nil { s.appendErrorMessage(err.Error()) log.Error(err) - return false + return err } // 保证关闭输出流 @@ -808,7 +812,7 @@ func (s *session) execCommand(r *Record, socketFile string, commandName string, if err := cmd.Start(); err != nil { s.appendErrorMessage(err.Error()) log.Error(err) - return false + return err } p := &util.OscProcessInfo{ @@ -893,7 +897,10 @@ func (s *session) execCommand(r *Record, socketFile string, commandName string, log.Error(err) } - if p.Percent < 100 || s.hasError() { + close(p.PanicAbort) + + allMessage := buf.String() + if p.Percent < 100 || s.myRecord.ErrLevel == 2 { s.recordSets.MaxLevel = 2 r.ErrLevel = 2 r.StageStatus = StatusExecFail @@ -903,16 +910,18 @@ func (s *session) execCommand(r *Record, socketFile string, commandName string, } if p.Percent < 100 || s.osc.OscPrintNone { - r.Buf.WriteString(buf.String()) + r.Buf.WriteString(allMessage) r.Buf.WriteString("\n") } - close(p.PanicAbort) + if p.Percent < 100 || s.myRecord.ErrLevel == 2 { + return fmt.Errorf(allMessage) + } + // 执行完成或中止后清理osc进程信息 // pl := s.sessionManager.ShowOscProcessList() // delete(pl, p.Sqlsha1) - - return true + return nil } func (s *session) mysqlAnalyzeOscOutput(out string, p *util.OscProcessInfo) { diff --git a/session/session_inception.go b/session/session_inception.go index d98f5a5c..842ce81a 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -3381,14 +3381,18 @@ func (s *session) checkAlterTable(node *ast.AlterTableStmt, sql string) { } } s.alterRollbackBuffer = nil - // if !s.hasError() && s.myRecord.useOsc { - // s.myRecord.ErrLevel = uint8(Max(int(s.myRecord.ErrLevel), 1)) - // if s.ghost.GhostOn { - // s.myRecord.Buf.WriteString("Will be executed using gh-ost.") - // } else { - // s.myRecord.Buf.WriteString("Will be executed using pt-osc.") - // } - // } + + if !s.hasError() && s.myRecord.useOsc && s.ghost.GhostOn && s.opt.Execute { + socketFile := s.getSocketFile(s.myRecord) + if _, err := os.Stat(socketFile); err == nil { + s.appendErrorMessage("listen unix socket file already in use") + return + } else if err != nil && !strings.Contains(err.Error(), "no such file or directory") { + log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) + s.appendErrorMessage(err.Error()) + return + } + } } func (s *session) checkAlterTableAlterColumn(t *TableInfo, c *ast.AlterTableSpec) { @@ -6080,6 +6084,10 @@ func (s *session) executeLocalOscKill(node *ast.ShowOscStmt) ([]sqlexec.RecordSe // s.sessionVars.StmtCtx.AppendWarning(errors.New("osc process has been aborted")) return nil, errors.New("osc process not aborted") } else { + if pi.Percent >= 100 { + return nil, errors.New("osc change has been completed") + } + if pi.SocketFile == "" { pi.PanicAbort <- util.ProcessOperationKill } else { @@ -6091,13 +6099,9 @@ func (s *session) executeLocalOscKill(node *ast.ShowOscStmt) ([]sqlexec.RecordSe } f.Close() // clean panic file - go func() { - timeTickerChan := time.Tick(time.Second * 10) - for { - <-timeTickerChan - os.Remove(panicFile) - } - }() + go time.AfterFunc(time.Second*10, func() { + os.Remove(panicFile) + }) } return nil, nil }