From a0825da2264a3e42406895cc920d48ed29e87de2 Mon Sep 17 00:00:00 2001 From: klarysz Date: Thu, 28 Jul 2022 20:41:45 +0200 Subject: [PATCH] patched type detection --- cmd/ast/param.go | 2 +- cmd/ast/testdata/case009/output.yaml | 2 +- cmd/columns.go | 28 ++++++++++++++++++++++++++-- view/connector.go | 13 +++++++++++-- view/pool.go | 5 +++++ view/sql.go | 15 ++++++++++----- 6 files changed, 54 insertions(+), 11 deletions(-) diff --git a/cmd/ast/param.go b/cmd/ast/param.go index 1fb1cb40..8d5a5b94 100644 --- a/cmd/ast/param.go +++ b/cmd/ast/param.go @@ -78,7 +78,7 @@ outer: } } } else { - typer = &ColumnType{ColumnName: text} + typer = &ColumnType{ColumnName: strings.ToLower(text)} } } diff --git a/cmd/ast/testdata/case009/output.yaml b/cmd/ast/testdata/case009/output.yaml index 13853fc5..414415c6 100644 --- a/cmd/ast/testdata/case009/output.yaml +++ b/cmd/ast/testdata/case009/output.yaml @@ -6,7 +6,7 @@ parameters: name: tID required: true typer: - columnname: T1.ID + columnname: t1.id source: |- SELECT *, ( SELECT abc diff --git a/cmd/columns.go b/cmd/columns.go index 55781c1b..f3b5df9f 100644 --- a/cmd/columns.go +++ b/cmd/columns.go @@ -5,6 +5,9 @@ import ( "database/sql" "fmt" "github.com/viant/sqlx/io" + rdata "github.com/viant/toolbox/data" + "github.com/viant/velty/ast/expr" + "github.com/viant/velty/parser" "strings" ) @@ -30,9 +33,29 @@ func (s *serverBuilder) updateTableColumnTypes(ctx context.Context, table *Table } func (s *serverBuilder) updatedColumns(table *Table, prefix, tableName string, db *sql.DB) { - SQL := "SELECT * FROM " + tableName + " WHERE 1 = 0" + parse, err := parser.Parse([]byte(tableName)) + var args []interface{} + expandMap := &rdata.Map{} + + if err == nil { + if anIndex := strings.Index(tableName, "SELECT"); anIndex != -1 { + + for _, statement := range parse.Stmt { + switch actual := statement.(type) { + case *expr.Select: + expandMap.SetValue(actual.FullName[1:], "?") + args = append(args, 0) + } + } + + tableName = expandMap.ExpandAsText(tableName) + } + } + + SQL := "SELECT * FROM " + tableName + " t WHERE 1 = 0" + fmt.Printf("checking %v ...\n", tableName) - query, err := db.QueryContext(context.Background(), SQL) + query, err := db.QueryContext(context.Background(), SQL, args...) if err != nil { s.logger.Write([]byte(fmt.Sprintf("error occured while updating table %v columns: %v", tableName, err))) return @@ -55,6 +78,7 @@ func (s *serverBuilder) updatedColumns(table *Table, prefix, tableName string, d if key != "" { key += "." } + key += column.Name() table.ColumnTypes[strings.ToLower(key)] = columnType } diff --git a/view/connector.go b/view/connector.go index aeab54ef..8c88ea9f 100644 --- a/view/connector.go +++ b/view/connector.go @@ -26,7 +26,7 @@ type ( //TODO add secure password storage db func() (*sql.DB, error) initialized bool - DBConfig + *DBConfig mux sync.Mutex } @@ -35,6 +35,7 @@ type ( ConnMaxIdleTimeMs int `json:",omitempty" yaml:",omitempty"` MaxOpenConns int `json:",omitempty" yaml:",omitempty"` ConnMaxLifetimeMs int `json:",omitempty" yaml:",omitempty"` + TimeoutTime int `json:",omitempty" yaml:",omitempty"` } ) @@ -67,6 +68,10 @@ func (c *Connector) Init(ctx context.Context, connectors Connectors) error { c.inherit(connector) } + if c.DBConfig == nil { + c.DBConfig = &DBConfig{} + } + if err := c.Validate(); err != nil { return err } @@ -98,7 +103,7 @@ func (c *Connector) DB(ctx context.Context) (*sql.DB, error) { } c.mux.Lock() - c.db = aDbPool.DB(ctx, c.Driver, dsn, &c.DBConfig) + c.db = aDbPool.DB(ctx, c.Driver, dsn, c.DBConfig) aDB, err := c.db() c.mux.Unlock() @@ -142,6 +147,10 @@ func (c *Connector) inherit(connector *Connector) { if c.Name == "" { c.Name = connector.Name } + + if c.DBConfig == nil { + c.DBConfig = connector.DBConfig + } } func (c *Connector) setDriverOptions(secret *scy.Secret) { diff --git a/view/pool.go b/view/pool.go index 8fcde855..e7377dc9 100644 --- a/view/pool.go +++ b/view/pool.go @@ -170,6 +170,11 @@ func (d *db) ctxWithTimeout(duration time.Duration) context.Context { func (p *dbPool) DB(ctx context.Context, driver, dsn string, config *DBConfig) func() (*sql.DB, error) { builder := &strings.Builder{} + + if config == nil { + config = &DBConfig{} + } + builder.WriteString(strconv.Itoa(config.ConnMaxLifetimeMs)) builder.WriteByte('#') builder.WriteString(strconv.Itoa(config.MaxIdleConns)) diff --git a/view/sql.go b/view/sql.go index 5fc60a65..051c7e22 100644 --- a/view/sql.go +++ b/view/sql.go @@ -155,11 +155,7 @@ func detectColumnsSQL(source string, v *View) (string, []interface{}, error) { SQL := sb.String() if source != v.Name && source != v.Table { - discover := metadata.EnrichWithDiscover(source, false) - replacement := rdata.Map{} - replacement.Put(keywords.AndCriteria[1:], "\n\n AND 1=0 ") - replacement.Put(keywords.WhereCriteria[1:], "\n\n WHERE 1=0 ") - SQL = replacement.ExpandAsText(discover) + SQL = ExpandWithFalseCondition(source) } var placeholders []interface{} @@ -172,6 +168,15 @@ func detectColumnsSQL(source string, v *View) (string, []interface{}, error) { return SQL, placeholders, nil } +func ExpandWithFalseCondition(source string) string { + discover := metadata.EnrichWithDiscover(source, false) + replacement := rdata.Map{} + replacement.Put(keywords.AndCriteria[1:], "\n\n AND 1=0 ") + replacement.Put(keywords.WhereCriteria[1:], "\n\n WHERE 1=0 ") + SQL := replacement.ExpandAsText(discover) + return SQL +} + func expandWithZeroValues(SQL string, template *Template) (string, error) { expandMap := rdata.Map{} for _, parameter := range template.Parameters {