Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SQL parsing for MVT provider #744

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions provider/postgis/postgis.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Provider struct {
const (
// We quote the field and table names to prevent colliding with postgres keywords.
stdSQL = `SELECT %[1]v FROM %[2]v WHERE "%[3]v" && ` + bboxToken
mvtSQL = `SELECT %[1]v FROM %[2]v`

// SQL to get the column names, without hitting the information_schema. Though it might be better to hit the information_schema.
fldsSQL = `SELECT * FROM %[1]v LIMIT 0;`
Expand Down Expand Up @@ -100,7 +101,7 @@ var isSelectQuery = regexp.MustCompile(`(?i)^((\s*)(--.*\n)?)*select`)
// !BBOX! - [Required] will be replaced with the bounding box of the tile before the query is sent to the database.
// !ZOOM! - [Optional] will be replaced with the "Z" (zoom) value of the requested tile.
//
func CreateProvider(config dict.Dicter) (*Provider, error) {
func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) {

host, err := config.String(ConfigKeyHost, nil)
if err != nil {
Expand Down Expand Up @@ -297,7 +298,7 @@ func CreateProvider(config dict.Dicter) (*Provider, error) {
// Tablename and Fields will be used to build the query.
// We need to do some work. We need to check to see Fields contains the geom and gid fields
// and if not add them to the list. If Fields list is empty/nil we will use '*' for the field list.
l.sql, err = genSQL(&l, p.pool, tblName, fields, true)
l.sql, err = genSQL(&l, p.pool, tblName, fields, true, providerType)
if err != nil {
return nil, fmt.Errorf("could not generate sql, for layer(%v): %v", lname, err)
}
Expand Down Expand Up @@ -420,9 +421,15 @@ func (p Provider) inspectLayerGeomType(l *Layer) error {
// https://github.com/go-spatial/tegola/issues/180
//
// case insensitive search

re := regexp.MustCompile(`(?i)ST_AsBinary`)
sql := re.ReplaceAllString(l.sql, "ST_GeometryType")

re = regexp.MustCompile(`(?i)(ST_AsMVTGeom\(.*\))`)
if re.MatchString(sql) {
sql = fmt.Sprintf("SELECT ST_GeometryType(%v) FROM (%v) as q", l.geomField, sql)
}

// we only need a single result set to sniff out the geometry type
sql = fmt.Sprintf("%v LIMIT 1", sql)

Expand Down Expand Up @@ -638,12 +645,21 @@ func (p Provider) MVTForLayers(ctx context.Context, tile provider.Tile, layers [

// ref: https://postgis.net/docs/ST_AsMVT.html
// bytea ST_AsMVT(anyelement row, text name, integer extent, text geom_name, text feature_id_name)

var featureIDName string

if l.IDFieldName() == "" {
featureIDName = "NULL"
} else {
featureIDName = fmt.Sprintf(`'%s'`, l.IDFieldName())
}

sqls = append(sqls, fmt.Sprintf(
`(SELECT ST_AsMVT(q,'%s',%d,'%s','%s') AS data FROM (%s) AS q)`,
`(SELECT ST_AsMVT(q,'%s',%d,'%s',%s) AS data FROM (%s) AS q)`,
layers[i].MVTName,
tegola.DefaultExtent,
l.GeomFieldName(),
l.IDFieldName(),
featureIDName,
sql,
))
}
Expand Down
8 changes: 6 additions & 2 deletions provider/postgis/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,9 @@ func init() {
// !BBOX! - [Required] will be replaced with the bounding box of the tile before the query is sent to the database.
// !ZOOM! - [Optional] will be replaced with the "Z" (zoom) value of the requested tile.
//
func NewTileProvider(config dict.Dicter) (provider.Tiler, error) { return CreateProvider(config) }
func NewMVTTileProvider(config dict.Dicter) (provider.MVTTiler, error) { return CreateProvider(config) }
func NewTileProvider(config dict.Dicter) (provider.Tiler, error) {
return CreateProvider(config, "postgis")
}
func NewMVTTileProvider(config dict.Dicter) (provider.MVTTiler, error) {
return CreateProvider(config, "mvt_postgis")
}
28 changes: 24 additions & 4 deletions provider/postgis/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ import (
"github.com/jackc/pgx/pgtype"
)

// isMVT will return true if the provider is MVT based
func isMVT(providerType string) bool {
return providerType == "mvt_postgis"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit, but since we're using "mvt_postgis" in more than 1 place (register.go and here), we should probably move it to a const.

}

// genSQL will fill in the SQL field of a layer given a pool, and list of fields.
func genSQL(l *Layer, pool *pgx.ConnPool, tblname string, flds []string, buffer bool) (sql string, err error) {
func genSQL(l *Layer, pool *pgx.ConnPool, tblname string, flds []string, buffer bool, providerType string) (sql string, err error) {

// we need to hit the database to see what the fields are.
if len(flds) == 0 {
Expand Down Expand Up @@ -61,10 +66,19 @@ func genSQL(l *Layer, pool *pgx.ConnPool, tblname string, flds []string, buffer

// to avoid field names possibly colliding with Postgres keywords,
// we wrap the field names in quotes

if fgeom == -1 {
flds = append(flds, fmt.Sprintf(`ST_AsBinary("%v") AS "%[1]v"`, l.geomField))
if isMVT(providerType) {
flds = append(flds, fmt.Sprintf(`"%v" AS "%[1]v"`, l.geomField))
} else {
flds = append(flds, fmt.Sprintf(`ST_AsBinary("%v") AS "%[1]v"`, l.geomField))
}
} else {
flds[fgeom] = fmt.Sprintf(`ST_AsBinary("%v") AS "%[1]v"`, l.geomField)
if isMVT(providerType) {
flds[fgeom] = fmt.Sprintf(`"%v" AS "%[1]v"`, l.geomField)
} else {
flds[fgeom] = fmt.Sprintf(`ST_AsBinary("%v") AS "%[1]v"`, l.geomField)
}
}

// add required id field
Expand All @@ -74,7 +88,13 @@ func genSQL(l *Layer, pool *pgx.ConnPool, tblname string, flds []string, buffer

selectClause := strings.Join(flds, ", ")

return fmt.Sprintf(stdSQL, selectClause, tblname, l.geomField), nil
sqlTmpl := stdSQL

if isMVT(providerType) {
sqlTmpl = mvtSQL
}

return fmt.Sprintf(sqlTmpl, selectClause, tblname, l.geomField), nil
}

const (
Expand Down