diff --git a/packages/flutter/lib/src/widgets/focus_traversal.dart b/packages/flutter/lib/src/widgets/focus_traversal.dart
index 7911ee898d68..fdd94f371b30 100644
--- a/packages/flutter/lib/src/widgets/focus_traversal.dart
+++ b/packages/flutter/lib/src/widgets/focus_traversal.dart
@@ -82,9 +82,6 @@ enum TraversalDirection {
   /// This direction is unaffected by the [Directionality] of the current
   /// context.
   left,
-
-  // TODO(gspencer): Add diagonal traversal directions used by TV remotes and
-  // game controllers when we support them.
 }
 
 /// An object used to specify a focus traversal policy used for configuring a
@@ -547,6 +544,46 @@ mixin DirectionalFocusTraversalPolicyMixin on FocusTraversalPolicy {
     return null;
   }
 
+  static int _verticalCompare(Offset target, Offset a, Offset b) {
+    return (a.dy - target.dy).abs().compareTo((b.dy - target.dy).abs());
+  }
+
+  static int _horizontalCompare(Offset target, Offset a, Offset b) {
+    return (a.dx - target.dx).abs().compareTo((b.dx - target.dx).abs());
+  }
+
+  // Sort the ones that are closest to target vertically first, and if two are
+  // the same vertical distance, pick the one that is closest horizontally.
+  static Iterable<FocusNode> _sortByDistancePreferVertical(Offset target, Iterable<FocusNode> nodes) {
+    final List<FocusNode> sorted = nodes.toList();
+    mergeSort<FocusNode>(sorted, compare: (FocusNode nodeA, FocusNode nodeB) {
+      final Offset a = nodeA.rect.center;
+      final Offset b = nodeB.rect.center;
+      final int vertical = _verticalCompare(target, a, b);
+      if (vertical == 0) {
+        return _horizontalCompare(target, a, b);
+      }
+      return vertical;
+    });
+    return sorted;
+  }
+
+  // Sort the ones that are closest horizontally first, and if two are the same
+  // horizontal distance, pick the one that is closest vertically.
+  static Iterable<FocusNode> _sortByDistancePreferHorizontal(Offset target, Iterable<FocusNode> nodes) {
+    final List<FocusNode> sorted = nodes.toList();
+    mergeSort<FocusNode>(sorted, compare: (FocusNode nodeA, FocusNode nodeB) {
+      final Offset a = nodeA.rect.center;
+      final Offset b = nodeB.rect.center;
+      final int horizontal = _horizontalCompare(target, a, b);
+      if (horizontal == 0) {
+        return _verticalCompare(target, a, b);
+      }
+      return horizontal;
+    });
+    return sorted;
+  }
+
   // Sorts nodes from left to right horizontally, and removes nodes that are
   // either to the right of the left side of the target node if we're going
   // left, or to the left of the right side of the target node if we're going
@@ -555,52 +592,54 @@ mixin DirectionalFocusTraversalPolicyMixin on FocusTraversalPolicy {
   // This doesn't need to take into account directionality because it is
   // typically intending to actually go left or right, not in a reading
   // direction.
-  Iterable<FocusNode>? _sortAndFilterHorizontally(
+  Iterable<FocusNode> _sortAndFilterHorizontally(
     TraversalDirection direction,
     Rect target,
-    FocusNode nearestScope,
+    Iterable<FocusNode> nodes,
   ) {
     assert(direction == TraversalDirection.left || direction == TraversalDirection.right);
-    final Iterable<FocusNode> nodes = nearestScope.traversalDescendants;
-    assert(!nodes.contains(nearestScope));
-    final List<FocusNode> sorted = nodes.toList();
-    mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) => a.rect.center.dx.compareTo(b.rect.center.dx));
-    Iterable<FocusNode>? result;
+    final Iterable<FocusNode> filtered;
     switch (direction) {
       case TraversalDirection.left:
-        result = sorted.where((FocusNode node) => node.rect != target && node.rect.center.dx <= target.left);
+        filtered = nodes.where((FocusNode node) => node.rect != target && node.rect.center.dx <= target.left);
         break;
       case TraversalDirection.right:
-        result = sorted.where((FocusNode node) => node.rect != target && node.rect.center.dx >= target.right);
+        filtered = nodes.where((FocusNode node) => node.rect != target && node.rect.center.dx >= target.right);
         break;
       case TraversalDirection.up:
       case TraversalDirection.down:
-        break;
+        throw ArgumentError('Invalid direction $direction');
     }
-    return result;
+    final List<FocusNode> sorted = filtered.toList();
+    // Sort all nodes from left to right.
+    mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) => a.rect.center.dx.compareTo(b.rect.center.dx));
+    return sorted;
   }
 
   // Sorts nodes from top to bottom vertically, and removes nodes that are
   // either below the top of the target node if we're going up, or above the
   // bottom of the target node if we're going down.
-  Iterable<FocusNode>? _sortAndFilterVertically(
+  Iterable<FocusNode> _sortAndFilterVertically(
     TraversalDirection direction,
     Rect target,
     Iterable<FocusNode> nodes,
   ) {
-    final List<FocusNode> sorted = nodes.toList();
-    mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) => a.rect.center.dy.compareTo(b.rect.center.dy));
+    assert(direction == TraversalDirection.up || direction == TraversalDirection.down);
+    final Iterable<FocusNode> filtered;
     switch (direction) {
       case TraversalDirection.up:
-        return sorted.where((FocusNode node) => node.rect != target && node.rect.center.dy <= target.top);
+        filtered = nodes.where((FocusNode node) => node.rect != target && node.rect.center.dy <= target.top);
+        break;
       case TraversalDirection.down:
-        return sorted.where((FocusNode node) => node.rect != target && node.rect.center.dy >= target.bottom);
+        filtered = nodes.where((FocusNode node) => node.rect != target && node.rect.center.dy >= target.bottom);
+        break;
       case TraversalDirection.left:
       case TraversalDirection.right:
-        break;
+        throw ArgumentError('Invalid direction $direction');
     }
-    assert(direction == TraversalDirection.up || direction == TraversalDirection.down);
-    return null;
+    final List<FocusNode> sorted = filtered.toList();
+    mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) => a.rect.center.dy.compareTo(b.rect.center.dy));
+    return sorted;
   }
 
   // Updates the policy data to keep the previously visited node so that we can
@@ -745,71 +784,55 @@ mixin DirectionalFocusTraversalPolicyMixin on FocusTraversalPolicy {
     switch (direction) {
       case TraversalDirection.down:
       case TraversalDirection.up:
-        Iterable<FocusNode>? eligibleNodes = _sortAndFilterVertically(
-          direction,
-          focusedChild.rect,
-          nearestScope.traversalDescendants,
-        );
+        Iterable<FocusNode> eligibleNodes = _sortAndFilterVertically(direction, focusedChild.rect, nearestScope.traversalDescendants);
+        if (eligibleNodes.isEmpty) {
+          break;
+        }
         if (focusedScrollable != null && !focusedScrollable.position.atEdge) {
-          final Iterable<FocusNode> filteredEligibleNodes = eligibleNodes!.where((FocusNode node) => Scrollable.maybeOf(node.context!) == focusedScrollable);
+          final Iterable<FocusNode> filteredEligibleNodes = eligibleNodes.where((FocusNode node) => Scrollable.maybeOf(node.context!) == focusedScrollable);
           if (filteredEligibleNodes.isNotEmpty) {
             eligibleNodes = filteredEligibleNodes;
           }
         }
-        if (eligibleNodes!.isEmpty) {
-          break;
-        }
-        List<FocusNode> sorted = eligibleNodes.toList();
         if (direction == TraversalDirection.up) {
-          sorted = sorted.reversed.toList();
+          eligibleNodes = eligibleNodes.toList().reversed;
         }
         // Find any nodes that intersect the band of the focused child.
         final Rect band = Rect.fromLTRB(focusedChild.rect.left, -double.infinity, focusedChild.rect.right, double.infinity);
-        final Iterable<FocusNode> inBand = sorted.where((FocusNode node) => !node.rect.intersect(band).isEmpty);
+        final Iterable<FocusNode> inBand = eligibleNodes.where((FocusNode node) => !node.rect.intersect(band).isEmpty);
         if (inBand.isNotEmpty) {
-          // The inBand list is already sorted by horizontal distance, so pick
-          // the closest one.
-          found = inBand.first;
+          found = _sortByDistancePreferVertical(focusedChild.rect.center, inBand).first;
           break;
         }
-        // Only out-of-band targets remain, so pick the one that is closest the
-        // to the center line horizontally.
-        mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) {
-          return (a.rect.center.dx - focusedChild.rect.center.dx).abs().compareTo((b.rect.center.dx - focusedChild.rect.center.dx).abs());
-        });
-        found = sorted.first;
+        // Only out-of-band targets are eligible, so pick the one that is
+        // closest the to the center line horizontally.
+        found = _sortByDistancePreferHorizontal(focusedChild.rect.center, eligibleNodes).first;
         break;
       case TraversalDirection.right:
       case TraversalDirection.left:
-        Iterable<FocusNode>? eligibleNodes = _sortAndFilterHorizontally(direction, focusedChild.rect, nearestScope);
+        Iterable<FocusNode> eligibleNodes = _sortAndFilterHorizontally(direction, focusedChild.rect, nearestScope.traversalDescendants);
+        if (eligibleNodes.isEmpty) {
+          break;
+        }
         if (focusedScrollable != null && !focusedScrollable.position.atEdge) {
-          final Iterable<FocusNode> filteredEligibleNodes = eligibleNodes!.where((FocusNode node) => Scrollable.maybeOf(node.context!) == focusedScrollable);
+          final Iterable<FocusNode> filteredEligibleNodes = eligibleNodes.where((FocusNode node) => Scrollable.maybeOf(node.context!) == focusedScrollable);
           if (filteredEligibleNodes.isNotEmpty) {
             eligibleNodes = filteredEligibleNodes;
           }
         }
-        if (eligibleNodes!.isEmpty) {
-          break;
-        }
-        List<FocusNode> sorted = eligibleNodes.toList();
         if (direction == TraversalDirection.left) {
-          sorted = sorted.reversed.toList();
+          eligibleNodes = eligibleNodes.toList().reversed;
         }
         // Find any nodes that intersect the band of the focused child.
         final Rect band = Rect.fromLTRB(-double.infinity, focusedChild.rect.top, double.infinity, focusedChild.rect.bottom);
-        final Iterable<FocusNode> inBand = sorted.where((FocusNode node) => !node.rect.intersect(band).isEmpty);
+        final Iterable<FocusNode> inBand = eligibleNodes.where((FocusNode node) => !node.rect.intersect(band).isEmpty);
         if (inBand.isNotEmpty) {
-          // The inBand list is already sorted by vertical distance, so pick the
-          // closest one.
-          found = inBand.first;
+          found = _sortByDistancePreferHorizontal(focusedChild.rect.center, inBand).first;
           break;
         }
-        // Only out-of-band targets remain, so pick the one that is closest the
+        // Only out-of-band targets are eligible, so pick the one that is
         // to the center line vertically.
-        mergeSort<FocusNode>(sorted, compare: (FocusNode a, FocusNode b) {
-          return (a.rect.center.dy - focusedChild.rect.center.dy).abs().compareTo((b.rect.center.dy - focusedChild.rect.center.dy).abs());
-        });
-        found = sorted.first;
+        found = _sortByDistancePreferVertical(focusedChild.rect.center, eligibleNodes).first;
         break;
     }
     if (found != null) {
@@ -892,8 +915,8 @@ class _ReadingOrderSortData with Diagnosticable {
     }
     if (common!.isEmpty) {
       // If there is no common ancestor, then arbitrarily pick the
-      // directionality of the first group, which is the equivalent of the "first
-      // strongly typed" item in a bidi algorithm.
+      // directionality of the first group, which is the equivalent of the
+      // "first strongly typed" item in a bidirectional algorithm.
       return list.first.directionality;
     }
     // Find the closest common ancestor. The memberAncestors list contains the
diff --git a/packages/flutter/test/widgets/focus_traversal_test.dart b/packages/flutter/test/widgets/focus_traversal_test.dart
index ff4c6e8054bf..394ff731d443 100644
--- a/packages/flutter/test/widgets/focus_traversal_test.dart
+++ b/packages/flutter/test/widgets/focus_traversal_test.dart
@@ -10,18 +10,6 @@ import 'package:flutter_test/flutter_test.dart';
 
 import 'semantics_tester.dart';
 
-/// Used to test removal of nodes while sorting.
-class SkipAllButFirstAndLastPolicy extends FocusTraversalPolicy with DirectionalFocusTraversalPolicyMixin {
-  @override
-  Iterable<FocusNode> sortDescendants(Iterable<FocusNode> descendants, FocusNode currentNode) {
-    return <FocusNode>[
-      descendants.first,
-      if (currentNode != descendants.first && currentNode != descendants.last) currentNode,
-      descendants.last,
-    ];
-  }
-}
-
 void main() {
   group(WidgetOrderTraversalPolicy, () {
     testWidgets('Find the initial focus if there is none yet.', (WidgetTester tester) async {
@@ -1343,27 +1331,21 @@ void main() {
     });
 
     testWidgets('Directional focus avoids hysteresis.', (WidgetTester tester) async {
-      final List<GlobalKey> keys = <GlobalKey>[
-        GlobalKey(debugLabel: 'row 1:1'),
-        GlobalKey(debugLabel: 'row 2:1'),
-        GlobalKey(debugLabel: 'row 2:2'),
-        GlobalKey(debugLabel: 'row 3:1'),
-        GlobalKey(debugLabel: 'row 3:2'),
-        GlobalKey(debugLabel: 'row 3:3'),
-      ];
-      List<bool?> focus = List<bool?>.generate(keys.length, (int _) => null);
+      List<bool?> focus = List<bool?>.generate(6, (int _) => null);
+      final List<FocusNode> nodes = List<FocusNode>.generate(6, (int index) => FocusNode(debugLabel: 'Node $index'));
       Focus makeFocus(int index) {
         return Focus(
-          debugLabel: keys[index].toString(),
+          debugLabel: '[$index]',
+          focusNode: nodes[index],
           onFocusChange: (bool isFocused) => focus[index] = isFocused,
-          child: SizedBox(width: 100, height: 100, key: keys[index]),
+          child: const SizedBox(width: 100, height: 100),
         );
       }
 
       /// Layout is:
-      ///           keys[0]
-      ///       keys[1] keys[2]
-      ///    keys[3] keys[4] keys[5]
+      ///          [0]
+      ///       [1]   [2]
+      ///    [3]   [4]   [5]
       await tester.pumpWidget(
         Directionality(
           textDirection: TextDirection.ltr,
@@ -1402,80 +1384,203 @@ void main() {
       );
 
       void clear() {
-        focus = List<bool?>.generate(keys.length, (int _) => null);
+        focus = List<bool?>.generate(focus.length, (int _) => null);
       }
 
-      final List<FocusNode> nodes = keys.map<FocusNode>((GlobalKey key) => Focus.of(tester.element(find.byKey(key)))).toList();
       final FocusNode scope = nodes[0].enclosingScope!;
       nodes[4].requestFocus();
 
-      void expectState(List<bool?> states) {
-        for (int index = 0; index < states.length; ++index) {
-          expect(focus[index], states[index] == null ? isNull : (states[index]! ? isTrue : isFalse));
-          if (states[index] == null) {
-            expect(nodes[index].hasFocus, isFalse);
-          } else {
-            expect(nodes[index].hasFocus, states[index]);
-          }
-          expect(scope.hasFocus, isTrue);
-        }
-      }
-
       // Test to make sure that the same path is followed backwards and forwards.
       await tester.pump();
-      expectState(<bool?>[null, null, null, null, true, null]);
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, true, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.up), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[null, null, true, null, false, null]);
+      expect(focus, orderedEquals(<bool?>[null, null, true, null, false, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.up), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[true, null, false, null, null, null]);
+      expect(focus, orderedEquals(<bool?>[true, null, false, null, null, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.down), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[false, null, true, null, null, null]);
+      expect(focus, orderedEquals(<bool?>[false, null, true, null, null, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.down), isTrue);
       await tester.pump();
-      expectState(<bool?>[null, null, false, null, true, null]);
+      expect(focus, orderedEquals(<bool?>[null, null, false, null, true, null]));
       clear();
 
       // Make sure that moving in a different axis clears the history.
       expect(scope.focusInDirection(TraversalDirection.left), isTrue);
       await tester.pump();
-      expectState(<bool?>[null, null, null, true, false, null]);
+      expect(focus, orderedEquals(<bool?>[null, null, null, true, false, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.up), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[null, true, null, false, null, null]);
+      expect(focus, orderedEquals(<bool?>[null, true, null, false, null, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.up), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[true, false, null, null, null, null]);
+      expect(focus, orderedEquals(<bool?>[true, false, null, null, null, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.down), isTrue);
       await tester.pump();
 
-      expectState(<bool?>[false, true, null, null, null, null]);
+      expect(focus, orderedEquals(<bool?>[false, true, null, null, null, null]));
       clear();
 
       expect(scope.focusInDirection(TraversalDirection.down), isTrue);
       await tester.pump();
-      expectState(<bool?>[null, false, null, true, null, null]);
+      expect(focus, orderedEquals(<bool?>[null, false, null, true, null, null]));
+      clear();
+    });
+
+    testWidgets('Directional prefers the closest node even on irregular grids', (WidgetTester tester) async {
+      const int cols = 3;
+      const int rows = 3;
+      List<bool?> focus = List<bool?>.generate(rows * cols, (int _) => null);
+      final List<FocusNode> nodes = List<FocusNode>.generate(rows * cols, (int index) => FocusNode(debugLabel: 'Node $index'));
+
+      Widget makeFocus(int row, int col) {
+        final int index = row * rows + col;
+        return Focus(
+          focusNode: nodes[index],
+          onFocusChange: (bool isFocused) => focus[index] = isFocused,
+          child: Container(
+            // Make some of the items a different size to test the code that
+            // checks for the closest node.
+            width: index == 3 ? 150 : 100,
+            height: index == 1 ? 150 : 100,
+            color: Colors.primaries[index],
+            child: Text('[$row, $col]'),
+          ),
+        );
+      }
+
+      /// Layout is:
+      ///           [0, 1]
+      ///    [0, 0] [    ] [0, 2]
+      ///    [  1,  0 ] [1, 1] [1, 2]
+      ///    [2, 0] [2, 1] [2, 2]
+      await tester.pumpWidget(
+        Directionality(
+          textDirection: TextDirection.ltr,
+          child: FocusTraversalGroup(
+            policy: WidgetOrderTraversalPolicy(),
+            child: FocusScope(
+              debugLabel: 'Scope',
+              child: Column(
+                children: <Widget>[
+                  Row(
+                    mainAxisAlignment: MainAxisAlignment.center,
+                    crossAxisAlignment: CrossAxisAlignment.end,
+                    children: <Widget>[
+                      makeFocus(0, 0),
+                      makeFocus(0, 1),
+                      makeFocus(0, 2),
+                    ],
+                  ),
+                  Row(
+                    mainAxisAlignment: MainAxisAlignment.center,
+                    children: <Widget>[
+                      makeFocus(1, 0),
+                      makeFocus(1, 1),
+                      makeFocus(1, 2),
+                    ],
+                  ),
+                  Row(
+                    mainAxisAlignment: MainAxisAlignment.center,
+                    children: <Widget>[
+                      makeFocus(2, 0),
+                      makeFocus(2, 1),
+                      makeFocus(2, 2),
+                    ],
+                  ),
+                ],
+              ),
+            ),
+          ),
+        ),
+      );
+
+      void clear() {
+        focus = List<bool?>.generate(focus.length, (int _) => null);
+      }
+
+      final FocusNode scope = nodes[0].enclosingScope!;
+
+      // Go down the center column and make sure that the focus stays in that
+      // column, even though the second row is irregular.
+      nodes[1].requestFocus();
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, true, null, null, null, null, null, null, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.down), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, false, null, null, true, null, null, null, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.down), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, false, null, null, true, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.down), isFalse);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, null, null, null, null, null]));
+      clear();
+
+      // Go back up the right column and make sure that the focus stays in that
+      // column, even though the second row is irregular.
+      expect(scope.focusInDirection(TraversalDirection.right), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, null, null, null, false, true]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.up), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, null, true, null, null, false]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.up), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, true, null, null, false, null, null, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.up), isFalse);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, null, null, null, null, null]));
+      clear();
+
+      // Go left on the top row and make sure that the focus stays in that
+      // row, even though the second column is irregular.
+      expect(scope.focusInDirection(TraversalDirection.left), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, true, false, null, null, null, null, null, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.left), isTrue);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[true, false, null, null, null, null, null, null, null]));
+      clear();
+
+      expect(scope.focusInDirection(TraversalDirection.left), isFalse);
+      await tester.pump();
+      expect(focus, orderedEquals(<bool?>[null, null, null, null, null, null, null, null, null]));
       clear();
     });
 
@@ -2129,7 +2234,7 @@ void main() {
       expect(node2.hasPrimaryFocus, isFalse);
     });
 
-    testWidgets("FocusTraversalGroup with skipTraversal for all descendents set to true doesn't cause an exception.", (WidgetTester tester) async {
+    testWidgets("FocusTraversalGroup with skipTraversal for all descendants set to true doesn't cause an exception.", (WidgetTester tester) async {
       final FocusNode node1 = FocusNode();
       final FocusNode node2 = FocusNode();
 
@@ -2352,3 +2457,15 @@ class TestRoute extends PageRouteBuilder<void> {
           },
         );
 }
+
+/// Used to test removal of nodes while sorting.
+class SkipAllButFirstAndLastPolicy extends FocusTraversalPolicy with DirectionalFocusTraversalPolicyMixin {
+  @override
+  Iterable<FocusNode> sortDescendants(Iterable<FocusNode> descendants, FocusNode currentNode) {
+    return <FocusNode>[
+      descendants.first,
+      if (currentNode != descendants.first && currentNode != descendants.last) currentNode,
+      descendants.last,
+    ];
+  }
+}