Skip to content

Commit

Permalink
sql: add casts for unknown and enum types to lookupCast
Browse files Browse the repository at this point in the history
Release note: None
  • Loading branch information
mgartner committed Jan 6, 2022
1 parent 3dc1fa6 commit 7d0f75a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
77 changes: 59 additions & 18 deletions pkg/sql/sem/tree/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1219,12 +1219,6 @@ func ValidCast(src, tgt *types.T, ctx CastContext) bool {
return true
}

// Unknown is the type given to an expression that statically evaluates to
// NULL. NULL can be cast to any type in any context.
if src.Oid() == oid.T_unknown {
return true
}

srcFamily := src.Family()
tgtFamily := tgt.Family()

Expand Down Expand Up @@ -1253,24 +1247,71 @@ func ValidCast(src, tgt *types.T, ctx CastContext) bool {

// If src and tgt are not both array or tuple types, check castMap for a
// valid cast.
c, ok := lookupCast(
src.Oid(),
tgt.Oid(),
false, /* intervalStyleEnabled */
false, /* dateStyleEnabled */
)
c, ok := lookupCast(src, tgt, false /* intervalStyleEnabled */, false /* dateStyleEnabled */)
if ok {
return c.maxContext >= ctx
}

return false
}

// lookupCast returns a cast that describes the cast from src to tgt if
// it exists. If it does not exist, ok=false is returned.
func lookupCast(src, tgt oid.Oid, intervalStyleEnabled, dateStyleEnabled bool) (cast, bool) {
if tgts, ok := castMap[src]; ok {
if c, ok := tgts[tgt]; ok {
// lookupCast returns a cast that describes the cast from src to tgt if it
// exists. If it does not exist, ok=false is returned.
func lookupCast(src, tgt *types.T, intervalStyleEnabled, dateStyleEnabled bool) (cast, bool) {
srcFamily := src.Family()
tgtFamily := tgt.Family()
srcFamily.Name()

// Unknown is the type given to an expression that statically evaluates
// to NULL. NULL can be immutably cast to any type in any context.
if srcFamily == types.UnknownFamily {
return cast{
maxContext: CastContextImplicit,
volatility: VolatilityImmutable,
}, true
}

// Enums have dynamic OIDs, so they can't be populated in castMap. Instead,
// we dynamically create cast structs for valid enum casts.
if srcFamily == types.EnumFamily && tgtFamily == types.StringFamily {
// Casts from enum types to strings are immutable and allowed in
// assignment contexts.
return cast{
maxContext: CastContextAssignment,
volatility: VolatilityImmutable,
}, true
}
if tgtFamily == types.EnumFamily {
switch srcFamily {
case types.StringFamily:
// Casts from string types to enums are immutable and allowed in
// explicit contexts.
return cast{
maxContext: CastContextExplicit,
volatility: VolatilityImmutable,
}, true
case types.UnknownFamily:
// Casts from unknown to enums are immutable and allowed in implicit
// contexts.
return cast{
maxContext: CastContextImplicit,
volatility: VolatilityImmutable,
}, true
case types.BytesFamily:
// Casts from byte types to enums are immutable and allowed in
// explicit contexts.
// TODO(mgartner): We may not want to support the cast from BYTES to
// ENUM because Postgres does not support it, and it's been the
// source of at least one minor bug (see #74316).
return cast{
maxContext: CastContextExplicit,
volatility: VolatilityImmutable,
}, true
}
}

if tgts, ok := castMap[src.Oid()]; ok {
if c, ok := tgts[tgt.Oid()]; ok {
if intervalStyleEnabled && c.intervalStyleAffected ||
dateStyleEnabled && c.dateStyleAffected {
c.volatility = VolatilityStable
Expand Down Expand Up @@ -1738,7 +1779,7 @@ func LookupCastVolatility(from, to *types.T, sd *sessiondata.SessionData) (_ Vol
}

// If the volatility has been set in castMap, return it.
c, ok := lookupCast(from.Oid(), to.Oid(), intervalStyleEnabled, dateStyleEnabled)
c, ok := lookupCast(from, to, intervalStyleEnabled, dateStyleEnabled)
if ok && c.volatility != volatilityTODO {
return c.volatility, true
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func resolveCast(
default:
var v Volatility
var hint string
c, ok := lookupCast(castFrom.Oid(), castTo.Oid(), intervalStyleEnabled, dateStyleEnabled)
c, ok := lookupCast(castFrom, castTo, intervalStyleEnabled, dateStyleEnabled)
if ok && c.volatility != volatilityTODO {
// If the volatility has been set in castMap, use it.
v = c.volatility
Expand Down

0 comments on commit 7d0f75a

Please sign in to comment.