diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index f7bbfa998c38..6c637ce0bde5 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -182,40 +182,37 @@ func journalProgress(db ethdb.KeyValueWriter, marker []byte, stats *generatorSta
 // The iteration start point will be assigned if the iterator is restored from
 // the last interruption. Max will be assigned in order to limit the maximum
 // amount of data involved in each iteration.
-func (dl *diskLayer) proveRange(root common.Hash, tr *trie.SecureTrie, prefix []byte, kind string, origin []byte, max int, onValue func([]byte) ([]byte, error)) ([][]byte, [][]byte, []byte, bool, error) {
+func (dl *diskLayer) proveRange(root common.Hash, tr *trie.SecureTrie, prefix []byte, kind string, origin []byte, max int, valueConvertFn func([]byte) ([]byte, error)) ([][]byte, [][]byte, []byte, bool, error) {
 	var (
 		keys  [][]byte
 		vals  [][]byte
-		count int
-		last  []byte
 		proof = rawdb.NewMemoryDatabase()
 	)
-
 	iter := dl.diskdb.NewIterator(prefix, origin)
 	defer iter.Release()
 
-	for iter.Next() && count < max {
+	for iter.Next() && len(keys) < max {
 		key := iter.Key()
 		if len(key) != len(prefix)+common.HashLength {
 			continue
 		}
-		if !bytes.HasPrefix(key, prefix) {
-			continue
-		}
-		last = common.CopyBytes(key[len(prefix):])
 		keys = append(keys, common.CopyBytes(key[len(prefix):]))
 
-		if onValue == nil {
+		if valueConvertFn == nil {
 			vals = append(vals, common.CopyBytes(iter.Value()))
-		} else {
-			val, err := onValue(common.CopyBytes(iter.Value()))
-			if err != nil {
-				log.Debug("Failed to convert the flat state", "kind", kind, "key", common.BytesToHash(key[len(prefix):]), "error", err)
-				return nil, nil, last, false, err
-			}
-			vals = append(vals, val)
+			continue
 		}
-		count += 1
+		val, err := valueConvertFn(iter.Value())
+		if err != nil {
+			log.Debug("Failed to convert the flat state", "kind", kind, "key", common.BytesToHash(key[len(prefix):]), "error", err)
+			return nil, nil, keys[len(keys)-1], false, err
+		}
+		vals = append(vals, val)
+	}
+	// Find out the key of last iterated element.
+	var last []byte
+	if len(keys) > 0 {
+		last = keys[len(keys)-1]
 	}
 	// Generate the Merkle proofs for the first and last element
 	if origin == nil {
@@ -239,7 +236,7 @@ func (dl *diskLayer) proveRange(root common.Hash, tr *trie.SecureTrie, prefix []
 	}
 	// Range prover says the trie still has some elements on the right side but
 	// the database is exhausted, then data loss is detected.
-	if cont && count < max {
+	if cont && len(keys) < max {
 		return nil, nil, last, false, errors.New("data loss in the state range")
 	}
 	return keys, vals, last, !cont, nil
@@ -248,14 +245,14 @@ func (dl *diskLayer) proveRange(root common.Hash, tr *trie.SecureTrie, prefix []
 // genRange generates the state segment with particular prefix. Generation can
 // either verify the correctness of existing state through rangeproof and skip
 // generation, or iterate trie to regenerate state on demand.
-func (dl *diskLayer) genRange(root common.Hash, prefix []byte, kind string, origin []byte, max int, stats *generatorStats, onState func(key []byte, val []byte, regen bool) error, onValue func([]byte) ([]byte, error)) (bool, []byte, error) {
+func (dl *diskLayer) genRange(root common.Hash, prefix []byte, kind string, origin []byte, max int, stats *generatorStats, onState func(key []byte, val []byte, regen bool) error, valueConvertFn func([]byte) ([]byte, error)) (bool, []byte, error) {
 	tr, err := trie.NewSecure(root, dl.triedb)
 	if err != nil {
 		stats.Log("Trie missing, state snapshotting paused", root, dl.genMarker)
 		return false, nil, errors.New("trie is missing")
 	}
 	// Use range prover to check the validity of the flat state in the range
-	keys, vals, last, exhausted, err := dl.proveRange(root, tr, prefix, kind, origin, max, onValue)
+	keys, vals, last, exhausted, err := dl.proveRange(root, tr, prefix, kind, origin, max, valueConvertFn)
 	if err == nil {
 		snapSuccessfulRangeProofMeter.Mark(1)
 		log.Debug("Proved state range", "kind", kind, "prefix", prefix, "origin", origin, "last", last)
@@ -427,6 +424,17 @@ func (dl *diskLayer) generate(stats *generatorStats) {
 					return nil // special case, the last is 0xffffffff...fff
 				}
 			}
+		} else {
+			// If the root is empty, we still need to ensure that any previous snapshot
+			// storage values are cleared
+			// TODO: investigate if this can be avoided, this will be very costly since it
+			// affects every single EOA account
+			//  - Perhaps we can avoid if where codeHash is emptyCode
+			prefix := append(rawdb.SnapshotStoragePrefix, accountHash.Bytes()...)
+			keyLen := len(rawdb.SnapshotStoragePrefix) + 2*common.HashLength
+			if err := wipeKeyRange(dl.diskdb, "storage", prefix, nil, nil, keyLen, snapWipedStorageMeter, false); err != nil {
+				return err
+			}
 		}
 		// Some account processed, unmark the marker
 		accMarker = nil
diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go
index 1a1ea06b5d76..ccf817cfe930 100644
--- a/core/state/snapshot/generate_test.go
+++ b/core/state/snapshot/generate_test.go
@@ -23,6 +23,7 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/ethdb/memorydb"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/trie"
@@ -56,10 +57,13 @@ func TestGeneration(t *testing.T) {
 	acc = &Account{Balance: big.NewInt(3), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()}
 	val, _ = rlp.EncodeToBytes(acc)
 	accTrie.Update([]byte("acc-3"), val) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2
-	accTrie.Commit(nil)                  // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd
-	triedb.Commit(common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), false, nil)
+	root, _ := accTrie.Commit(nil)       // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd
+	triedb.Commit(root, false, nil)
 
-	snap := generateSnapshot(diskdb, triedb, 16, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"))
+	if have, want := root, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"); have != want {
+		t.Fatalf("have %#x want %#x", have, want)
+	}
+	snap := generateSnapshot(diskdb, triedb, 16, root)
 	select {
 	case <-snap.genPending:
 		// Snapshot generation succeeded
@@ -67,6 +71,7 @@ func TestGeneration(t *testing.T) {
 	case <-time.After(250 * time.Millisecond):
 		t.Errorf("Snapshot generation failed")
 	}
+	checkSnapRoot(t, snap, root)
 	// Signal abortion to the generator and wait for it to tear down
 	stop := make(chan *generatorStats)
 	snap.genAbort <- stop
@@ -120,10 +125,106 @@ func TestGenerateExistentState(t *testing.T) {
 	rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-3")), hashData([]byte("key-2")), []byte("val-2"))
 	rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-3")), hashData([]byte("key-3")), []byte("val-3"))
 
-	accTrie.Commit(nil) // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd
-	triedb.Commit(common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), false, nil)
+	root, _ := accTrie.Commit(nil) // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd
+	triedb.Commit(root, false, nil)
 
-	snap := generateSnapshot(diskdb, triedb, 16, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"))
+	snap := generateSnapshot(diskdb, triedb, 16, root)
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+
+	case <-time.After(250 * time.Millisecond):
+		t.Errorf("Snapshot generation failed")
+	}
+	checkSnapRoot(t, snap, root)
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+func checkSnapRoot(t *testing.T, snap *diskLayer, trieRoot common.Hash) {
+	t.Helper()
+	accIt := snap.AccountIterator(common.Hash{})
+	defer accIt.Release()
+	snapRoot, err := generateTrieRoot(nil, accIt, common.Hash{}, stackTrieGenerate,
+		func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) {
+			storageIt, _ := snap.StorageIterator(accountHash, common.Hash{})
+			defer storageIt.Release()
+
+			hash, err := generateTrieRoot(nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false)
+			if err != nil {
+				return common.Hash{}, err
+			}
+			return hash, nil
+		}, newGenerateStats(), true)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	if snapRoot != trieRoot {
+		t.Fatalf("snaproot: %#x != trieroot #%x", snapRoot, trieRoot)
+	}
+}
+
+// Tests that snapshot generation with existent flat state, where the flat state contains
+// some errors
+func TestGenerateExistentStateWithExtraStorage(t *testing.T) {
+	//log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stderr, log.TerminalFormat(true))))
+
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// two of which also has the same 3-slot storage trie attached.
+	var (
+		diskdb = memorydb.New()
+		triedb = trie.NewDatabase(diskdb)
+	)
+	stTrie, _ := trie.NewSecure(common.Hash{}, triedb)
+	stTrie.Update([]byte("key-1"), []byte("val-1"))
+	stTrie.Update([]byte("key-2"), []byte("val-2"))
+	stTrie.Update([]byte("key-3"), []byte("val-3"))
+	stTrie.Commit(nil)
+
+	accTrie, _ := trie.NewSecure(common.Hash{}, triedb)
+
+	{ // Account one
+		acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()}
+		val, _ := rlp.EncodeToBytes(acc)
+		accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
+		rawdb.WriteAccountSnapshot(diskdb, hashData([]byte("acc-1")), val)
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-1")), hashData([]byte("key-1")), []byte("val-1"))
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-1")), hashData([]byte("key-2")), []byte("val-2"))
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-1")), hashData([]byte("key-3")), []byte("val-3"))
+	}
+	{ // Account two
+		// The storage root is emptyHash, but the flat db has some storage values. This can happen
+		// if the storage was unset during sync
+		acc := &Account{Balance: big.NewInt(2), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()}
+		val, _ := rlp.EncodeToBytes(acc)
+		accTrie.Update([]byte("acc-2"), val) // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7
+		diskdb.Put(hashData([]byte("acc-2")).Bytes(), val)
+		rawdb.WriteAccountSnapshot(diskdb, hashData([]byte("acc-2")), val)
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-2")), hashData([]byte("key-1")), []byte("val-1"))
+	}
+
+	{ // Account three
+		// This account changed codehash
+		acc := &Account{Balance: big.NewInt(3), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()}
+		val, _ := rlp.EncodeToBytes(acc)
+		accTrie.Update([]byte("acc-3"), val) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2
+		acc.CodeHash = hashData([]byte("codez")).Bytes()
+		val, _ = rlp.EncodeToBytes(acc)
+		rawdb.WriteAccountSnapshot(diskdb, hashData([]byte("acc-3")), val)
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-3")), hashData([]byte("key-1")), []byte("val-1"))
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-3")), hashData([]byte("key-2")), []byte("val-2"))
+		rawdb.WriteStorageSnapshot(diskdb, hashData([]byte("acc-3")), hashData([]byte("key-3")), []byte("val-3"))
+	}
+
+	root, _ := accTrie.Commit(nil) // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd
+	t.Logf("Root: %#x\n", root)
+	triedb.Commit(root, false, nil)
+
+	snap := generateSnapshot(diskdb, triedb, 16, root)
 	select {
 	case <-snap.genPending:
 		// Snapshot generation succeeded
@@ -131,6 +232,7 @@ func TestGenerateExistentState(t *testing.T) {
 	case <-time.After(250 * time.Millisecond):
 		t.Errorf("Snapshot generation failed")
 	}
+	checkSnapRoot(t, snap, root)
 	// Signal abortion to the generator and wait for it to tear down
 	stop := make(chan *generatorStats)
 	snap.genAbort <- stop