From 04cdb9c6720720c9663cc10747a6835ae0101c3e Mon Sep 17 00:00:00 2001 From: recall704 Date: Fri, 2 Apr 2021 11:59:12 +0800 Subject: [PATCH] add conn session vars --- v4/export/config.go | 11 +++++++++++ v4/export/dump.go | 7 +++++++ v4/export/sql.go | 27 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/v4/export/config.go b/v4/export/config.go index 00d86803..520f28d9 100644 --- a/v4/export/config.go +++ b/v4/export/config.go @@ -67,6 +67,7 @@ const ( flagOutputFilenameTemplate = "output-filename-template" flagCompleteInsert = "complete-insert" flagParams = "params" + flagSessionParams = "session-params" flagReadTimeout = "read-timeout" flagTransactionalConsistency = "transactional-consistency" flagCompress = "compress" @@ -128,6 +129,7 @@ type Config struct { FileSize uint64 StatementSize uint64 SessionParams map[string]interface{} + ConnSessionParams map[string]interface{} Labels prometheus.Labels `json:"-"` Tables DatabaseTables } @@ -164,6 +166,7 @@ func DefaultConfig() *Config { TableFilter: allFilter, DumpEmptyDatabase: true, SessionParams: make(map[string]interface{}), + ConnSessionParams: make(map[string]interface{}), OutputFileTemplate: DefaultOutputFileTemplate, PosAfterConnect: false, } @@ -240,6 +243,7 @@ func (conf *Config) DefineFlags(flags *pflag.FlagSet) { flags.String(flagOutputFilenameTemplate, "", "The output filename template (without file extension)") flags.Bool(flagCompleteInsert, false, "Use complete INSERT statements that include column names") flags.StringToString(flagParams, nil, `Extra session variables used while dumping, accepted format: --params "character_set_client=latin1,character_set_connection=latin1"`) + flags.StringToString(flagSessionParams, nil, `Extra session variables for dumping connection, accepted format: --session-params "net_read_timeout=86400,interactive_timeout=28800,wait_timeout=2147483,net_write_timeout=86400"`) flags.Bool(FlagHelp, false, "Print help message and quit") flags.Duration(flagReadTimeout, 15*time.Minute, "I/O read timeout for db connection.") _ = flags.MarkHidden(flagReadTimeout) @@ -428,6 +432,13 @@ func (conf *Config) ParseFromFlags(flags *pflag.FlagSet) error { if err != nil { return errors.Trace(err) } + connSessionParams, err := flags.GetStringToString(flagSessionParams) + if err != nil { + return errors.Trace(err) + } + for k, v := range connSessionParams { + conf.ConnSessionParams[k] = v + } conf.TableFilter, err = ParseTableFilter(tablesList, filters) if err != nil { diff --git a/v4/export/dump.go b/v4/export/dump.go index 6eb56707..d2f08519 100755 --- a/v4/export/dump.go +++ b/v4/export/dump.go @@ -120,6 +120,9 @@ func (d *Dumper) Dump() (dumpErr error) { return err } defer metaConn.Close() + // set conn session timeout + setConnSessionVariables(tctx, metaConn, conf.ConnSessionParams) + m.recordStartTime(time.Now()) // for consistency lock, we can write snapshot info after all tables are locked. // the binlog pos may changed because there is still possible write between we lock tables and write master status. @@ -152,6 +155,8 @@ func (d *Dumper) Dump() (dumpErr error) { return conn, errors.Trace(err1) } conn = newConn + // set conn session timeout + setConnSessionVariables(tctx, metaConn, conf.ConnSessionParams) // renew the master status after connection. dm can't close safe-mode until dm reaches current pos if conf.PosAfterConnect { err1 = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType, true) @@ -246,6 +251,8 @@ func (d *Dumper) startWriters(tctx *tcontext.Context, wg *errgroup.Group, taskCh if err != nil { return nil, func() {}, err } + // set conn session timeout + setConnSessionVariables(tctx, conn, conf.ConnSessionParams) writer := NewWriter(tctx, int64(i), conf, conn, d.extStore) writer.rebuildConnFn = rebuildConnFn writer.setFinishTableCallBack(func(task Task) { diff --git a/v4/export/sql.go b/v4/export/sql.go index 78c86526..24718e51 100644 --- a/v4/export/sql.go +++ b/v4/export/sql.go @@ -982,3 +982,30 @@ func buildWhereCondition(conf *Config, where string) string { func escapeString(s string) string { return strings.ReplaceAll(s, "`", "``") } + +func setConnSessionVariables(tctx *tcontext.Context, conn *sql.Conn, params map[string]interface{}) { + for k, v := range params { + var s string + if str, ok := v.(string); ok { + if _, err := strconv.ParseInt(str, 10, 64); err == nil { + // try to parse int + s = fmt.Sprintf("%v", v) + } else if _, err := strconv.ParseFloat(str, 64); err == nil { + // try to parse float + s = fmt.Sprintf("%v", v) + } else { + s = wrapStringWith(str, "'") + } + } else { + s = fmt.Sprintf("%v", v) + } + query := fmt.Sprintf("SET SESSION %s = %s", k, s) + _, err := conn.ExecContext(tctx, query) + if err != nil { + tctx.L().Warn("fail to set conn session vars", + zap.String("query", query), + zap.Error(err), + ) + } + } +}