Skip to content

Commit ed15e82

Browse files
committed
runtime: panic on uncomparable map key, even if map is empty
Reorg map flags a bit so we don't need any extra space for the extra flag. Fixes #23734 Change-Id: I436812156240ae90de53d0943fe1aabf3ea37417 Reviewed-on: https://go-review.googlesource.com/c/155918 Run-TryBot: Keith Randall <khr@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Ian Lance Taylor <iant@golang.org>
1 parent 14bdcc7 commit ed15e82

File tree

5 files changed

+172
-60
lines changed

5 files changed

+172
-60
lines changed

src/cmd/compile/internal/gc/reflect.go

+37-7
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,28 @@ func needkeyupdate(t *types.Type) bool {
10951095
}
10961096
}
10971097

1098+
// hashMightPanic reports whether the hash of a map key of type t might panic.
1099+
func hashMightPanic(t *types.Type) bool {
1100+
switch t.Etype {
1101+
case TINTER:
1102+
return true
1103+
1104+
case TARRAY:
1105+
return hashMightPanic(t.Elem())
1106+
1107+
case TSTRUCT:
1108+
for _, t1 := range t.Fields().Slice() {
1109+
if hashMightPanic(t1.Type) {
1110+
return true
1111+
}
1112+
}
1113+
return false
1114+
1115+
default:
1116+
return false
1117+
}
1118+
}
1119+
10981120
// formalType replaces byte and rune aliases with real types.
10991121
// They've been separate internally to make error messages
11001122
// better, but we have to merge them in the reflect tables.
@@ -1257,25 +1279,33 @@ func dtypesym(t *types.Type) *obj.LSym {
12571279
ot = dsymptr(lsym, ot, s1, 0)
12581280
ot = dsymptr(lsym, ot, s2, 0)
12591281
ot = dsymptr(lsym, ot, s3, 0)
1282+
var flags uint32
1283+
// Note: flags must match maptype accessors in ../../../../runtime/type.go
1284+
// and maptype builder in ../../../../reflect/type.go:MapOf.
12601285
if t.Key().Width > MAXKEYSIZE {
12611286
ot = duint8(lsym, ot, uint8(Widthptr))
1262-
ot = duint8(lsym, ot, 1) // indirect
1287+
flags |= 1 // indirect key
12631288
} else {
12641289
ot = duint8(lsym, ot, uint8(t.Key().Width))
1265-
ot = duint8(lsym, ot, 0) // not indirect
12661290
}
12671291

12681292
if t.Elem().Width > MAXVALSIZE {
12691293
ot = duint8(lsym, ot, uint8(Widthptr))
1270-
ot = duint8(lsym, ot, 1) // indirect
1294+
flags |= 2 // indirect value
12711295
} else {
12721296
ot = duint8(lsym, ot, uint8(t.Elem().Width))
1273-
ot = duint8(lsym, ot, 0) // not indirect
12741297
}
1275-
12761298
ot = duint16(lsym, ot, uint16(bmap(t).Width))
1277-
ot = duint8(lsym, ot, uint8(obj.Bool2int(isreflexive(t.Key()))))
1278-
ot = duint8(lsym, ot, uint8(obj.Bool2int(needkeyupdate(t.Key()))))
1299+
if isreflexive(t.Key()) {
1300+
flags |= 4 // reflexive key
1301+
}
1302+
if needkeyupdate(t.Key()) {
1303+
flags |= 8 // need key update
1304+
}
1305+
if hashMightPanic(t.Key()) {
1306+
flags |= 16 // hash might panic
1307+
}
1308+
ot = duint32(lsym, ot, flags)
12791309
ot = dextratype(lsym, ot, t, 0)
12801310

12811311
case TPTR:

src/reflect/type.go

+42-16
Original file line numberDiff line numberDiff line change
@@ -394,16 +394,13 @@ type interfaceType struct {
394394
// mapType represents a map type.
395395
type mapType struct {
396396
rtype
397-
key *rtype // map key type
398-
elem *rtype // map element (value) type
399-
bucket *rtype // internal bucket structure
400-
keysize uint8 // size of key slot
401-
indirectkey uint8 // store ptr to key instead of key itself
402-
valuesize uint8 // size of value slot
403-
indirectvalue uint8 // store ptr to value instead of value itself
404-
bucketsize uint16 // size of bucket
405-
reflexivekey bool // true if k==k for all keys
406-
needkeyupdate bool // true if we need to update key on an overwrite
397+
key *rtype // map key type
398+
elem *rtype // map element (value) type
399+
bucket *rtype // internal bucket structure
400+
keysize uint8 // size of key slot
401+
valuesize uint8 // size of value slot
402+
bucketsize uint16 // size of bucket
403+
flags uint32
407404
}
408405

409406
// ptrType represents a pointer type.
@@ -1859,6 +1856,8 @@ func MapOf(key, elem Type) Type {
18591856
}
18601857

18611858
// Make a map type.
1859+
// Note: flag values must match those used in the TMAP case
1860+
// in ../cmd/compile/internal/gc/reflect.go:dtypesym.
18621861
var imap interface{} = (map[unsafe.Pointer]unsafe.Pointer)(nil)
18631862
mt := **(**mapType)(unsafe.Pointer(&imap))
18641863
mt.str = resolveReflectName(newName(s, "", false))
@@ -1867,23 +1866,29 @@ func MapOf(key, elem Type) Type {
18671866
mt.key = ktyp
18681867
mt.elem = etyp
18691868
mt.bucket = bucketOf(ktyp, etyp)
1869+
mt.flags = 0
18701870
if ktyp.size > maxKeySize {
18711871
mt.keysize = uint8(ptrSize)
1872-
mt.indirectkey = 1
1872+
mt.flags |= 1 // indirect key
18731873
} else {
18741874
mt.keysize = uint8(ktyp.size)
1875-
mt.indirectkey = 0
18761875
}
18771876
if etyp.size > maxValSize {
18781877
mt.valuesize = uint8(ptrSize)
1879-
mt.indirectvalue = 1
1878+
mt.flags |= 2 // indirect value
18801879
} else {
18811880
mt.valuesize = uint8(etyp.size)
1882-
mt.indirectvalue = 0
18831881
}
18841882
mt.bucketsize = uint16(mt.bucket.size)
1885-
mt.reflexivekey = isReflexive(ktyp)
1886-
mt.needkeyupdate = needKeyUpdate(ktyp)
1883+
if isReflexive(ktyp) {
1884+
mt.flags |= 4
1885+
}
1886+
if needKeyUpdate(ktyp) {
1887+
mt.flags |= 8
1888+
}
1889+
if hashMightPanic(ktyp) {
1890+
mt.flags |= 16
1891+
}
18871892
mt.ptrToThis = 0
18881893

18891894
ti, _ := lookupCache.LoadOrStore(ckey, &mt.rtype)
@@ -2122,6 +2127,27 @@ func needKeyUpdate(t *rtype) bool {
21222127
}
21232128
}
21242129

2130+
// hashMightPanic reports whether the hash of a map key of type t might panic.
2131+
func hashMightPanic(t *rtype) bool {
2132+
switch t.Kind() {
2133+
case Interface:
2134+
return true
2135+
case Array:
2136+
tt := (*arrayType)(unsafe.Pointer(t))
2137+
return hashMightPanic(tt.elem)
2138+
case Struct:
2139+
tt := (*structType)(unsafe.Pointer(t))
2140+
for _, f := range tt.fields {
2141+
if hashMightPanic(f.typ) {
2142+
return true
2143+
}
2144+
}
2145+
return false
2146+
default:
2147+
return false
2148+
}
2149+
}
2150+
21252151
// Make sure these routines stay in sync with ../../runtime/map.go!
21262152
// These types exist only for GC, so we only fill out GC relevant info.
21272153
// Currently, that's just size and the GC program. We also fill in string

src/runtime/map.go

+35-26
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ func mapaccess1(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
404404
msanread(key, t.key.size)
405405
}
406406
if h == nil || h.count == 0 {
407+
if t.hashMightPanic() {
408+
t.key.alg.hash(key, 0) // see issue 23734
409+
}
407410
return unsafe.Pointer(&zeroVal[0])
408411
}
409412
if h.flags&hashWriting != 0 {
@@ -434,12 +437,12 @@ bucketloop:
434437
continue
435438
}
436439
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
437-
if t.indirectkey {
440+
if t.indirectkey() {
438441
k = *((*unsafe.Pointer)(k))
439442
}
440443
if alg.equal(key, k) {
441444
v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
442-
if t.indirectvalue {
445+
if t.indirectvalue() {
443446
v = *((*unsafe.Pointer)(v))
444447
}
445448
return v
@@ -460,6 +463,9 @@ func mapaccess2(t *maptype, h *hmap, key unsafe.Pointer) (unsafe.Pointer, bool)
460463
msanread(key, t.key.size)
461464
}
462465
if h == nil || h.count == 0 {
466+
if t.hashMightPanic() {
467+
t.key.alg.hash(key, 0) // see issue 23734
468+
}
463469
return unsafe.Pointer(&zeroVal[0]), false
464470
}
465471
if h.flags&hashWriting != 0 {
@@ -490,12 +496,12 @@ bucketloop:
490496
continue
491497
}
492498
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
493-
if t.indirectkey {
499+
if t.indirectkey() {
494500
k = *((*unsafe.Pointer)(k))
495501
}
496502
if alg.equal(key, k) {
497503
v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
498-
if t.indirectvalue {
504+
if t.indirectvalue() {
499505
v = *((*unsafe.Pointer)(v))
500506
}
501507
return v, true
@@ -535,12 +541,12 @@ bucketloop:
535541
continue
536542
}
537543
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
538-
if t.indirectkey {
544+
if t.indirectkey() {
539545
k = *((*unsafe.Pointer)(k))
540546
}
541547
if alg.equal(key, k) {
542548
v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
543-
if t.indirectvalue {
549+
if t.indirectvalue() {
544550
v = *((*unsafe.Pointer)(v))
545551
}
546552
return k, v
@@ -620,14 +626,14 @@ bucketloop:
620626
continue
621627
}
622628
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
623-
if t.indirectkey {
629+
if t.indirectkey() {
624630
k = *((*unsafe.Pointer)(k))
625631
}
626632
if !alg.equal(key, k) {
627633
continue
628634
}
629635
// already have a mapping for key. Update it.
630-
if t.needkeyupdate {
636+
if t.needkeyupdate() {
631637
typedmemmove(t.key, k, key)
632638
}
633639
val = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
@@ -658,12 +664,12 @@ bucketloop:
658664
}
659665

660666
// store new key/value at insert position
661-
if t.indirectkey {
667+
if t.indirectkey() {
662668
kmem := newobject(t.key)
663669
*(*unsafe.Pointer)(insertk) = kmem
664670
insertk = kmem
665671
}
666-
if t.indirectvalue {
672+
if t.indirectvalue() {
667673
vmem := newobject(t.elem)
668674
*(*unsafe.Pointer)(val) = vmem
669675
}
@@ -676,7 +682,7 @@ done:
676682
throw("concurrent map writes")
677683
}
678684
h.flags &^= hashWriting
679-
if t.indirectvalue {
685+
if t.indirectvalue() {
680686
val = *((*unsafe.Pointer)(val))
681687
}
682688
return val
@@ -693,6 +699,9 @@ func mapdelete(t *maptype, h *hmap, key unsafe.Pointer) {
693699
msanread(key, t.key.size)
694700
}
695701
if h == nil || h.count == 0 {
702+
if t.hashMightPanic() {
703+
t.key.alg.hash(key, 0) // see issue 23734
704+
}
696705
return
697706
}
698707
if h.flags&hashWriting != 0 {
@@ -724,20 +733,20 @@ search:
724733
}
725734
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
726735
k2 := k
727-
if t.indirectkey {
736+
if t.indirectkey() {
728737
k2 = *((*unsafe.Pointer)(k2))
729738
}
730739
if !alg.equal(key, k2) {
731740
continue
732741
}
733742
// Only clear key if there are pointers in it.
734-
if t.indirectkey {
743+
if t.indirectkey() {
735744
*(*unsafe.Pointer)(k) = nil
736745
} else if t.key.kind&kindNoPointers == 0 {
737746
memclrHasPointers(k, t.key.size)
738747
}
739748
v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.valuesize))
740-
if t.indirectvalue {
749+
if t.indirectvalue() {
741750
*(*unsafe.Pointer)(v) = nil
742751
} else if t.elem.kind&kindNoPointers == 0 {
743752
memclrHasPointers(v, t.elem.size)
@@ -897,7 +906,7 @@ next:
897906
continue
898907
}
899908
k := add(unsafe.Pointer(b), dataOffset+uintptr(offi)*uintptr(t.keysize))
900-
if t.indirectkey {
909+
if t.indirectkey() {
901910
k = *((*unsafe.Pointer)(k))
902911
}
903912
v := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+uintptr(offi)*uintptr(t.valuesize))
@@ -909,7 +918,7 @@ next:
909918
// through the oldbucket, skipping any keys that will go
910919
// to the other new bucket (each oldbucket expands to two
911920
// buckets during a grow).
912-
if t.reflexivekey || alg.equal(k, k) {
921+
if t.reflexivekey() || alg.equal(k, k) {
913922
// If the item in the oldbucket is not destined for
914923
// the current new bucket in the iteration, skip it.
915924
hash := alg.hash(k, uintptr(h.hash0))
@@ -930,13 +939,13 @@ next:
930939
}
931940
}
932941
if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) ||
933-
!(t.reflexivekey || alg.equal(k, k)) {
942+
!(t.reflexivekey() || alg.equal(k, k)) {
934943
// This is the golden data, we can return it.
935944
// OR
936945
// key!=key, so the entry can't be deleted or updated, so we can just return it.
937946
// That's lucky for us because when key!=key we can't look it up successfully.
938947
it.key = k
939-
if t.indirectvalue {
948+
if t.indirectvalue() {
940949
v = *((*unsafe.Pointer)(v))
941950
}
942951
it.value = v
@@ -1160,15 +1169,15 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
11601169
throw("bad map state")
11611170
}
11621171
k2 := k
1163-
if t.indirectkey {
1172+
if t.indirectkey() {
11641173
k2 = *((*unsafe.Pointer)(k2))
11651174
}
11661175
var useY uint8
11671176
if !h.sameSizeGrow() {
11681177
// Compute hash to make our evacuation decision (whether we need
11691178
// to send this key/value to bucket x or bucket y).
11701179
hash := t.key.alg.hash(k2, uintptr(h.hash0))
1171-
if h.flags&iterator != 0 && !t.reflexivekey && !t.key.alg.equal(k2, k2) {
1180+
if h.flags&iterator != 0 && !t.reflexivekey() && !t.key.alg.equal(k2, k2) {
11721181
// If key != key (NaNs), then the hash could be (and probably
11731182
// will be) entirely different from the old hash. Moreover,
11741183
// it isn't reproducible. Reproducibility is required in the
@@ -1203,12 +1212,12 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
12031212
dst.v = add(dst.k, bucketCnt*uintptr(t.keysize))
12041213
}
12051214
dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
1206-
if t.indirectkey {
1215+
if t.indirectkey() {
12071216
*(*unsafe.Pointer)(dst.k) = k2 // copy pointer
12081217
} else {
12091218
typedmemmove(t.key, dst.k, k) // copy value
12101219
}
1211-
if t.indirectvalue {
1220+
if t.indirectvalue() {
12121221
*(*unsafe.Pointer)(dst.v) = *(*unsafe.Pointer)(v)
12131222
} else {
12141223
typedmemmove(t.elem, dst.v, v)
@@ -1274,12 +1283,12 @@ func reflect_makemap(t *maptype, cap int) *hmap {
12741283
if !ismapkey(t.key) {
12751284
throw("runtime.reflect_makemap: unsupported map key type")
12761285
}
1277-
if t.key.size > maxKeySize && (!t.indirectkey || t.keysize != uint8(sys.PtrSize)) ||
1278-
t.key.size <= maxKeySize && (t.indirectkey || t.keysize != uint8(t.key.size)) {
1286+
if t.key.size > maxKeySize && (!t.indirectkey() || t.keysize != uint8(sys.PtrSize)) ||
1287+
t.key.size <= maxKeySize && (t.indirectkey() || t.keysize != uint8(t.key.size)) {
12791288
throw("key size wrong")
12801289
}
1281-
if t.elem.size > maxValueSize && (!t.indirectvalue || t.valuesize != uint8(sys.PtrSize)) ||
1282-
t.elem.size <= maxValueSize && (t.indirectvalue || t.valuesize != uint8(t.elem.size)) {
1290+
if t.elem.size > maxValueSize && (!t.indirectvalue() || t.valuesize != uint8(sys.PtrSize)) ||
1291+
t.elem.size <= maxValueSize && (t.indirectvalue() || t.valuesize != uint8(t.elem.size)) {
12831292
throw("value size wrong")
12841293
}
12851294
if t.key.align > bucketCnt {

0 commit comments

Comments
 (0)