@@ -227,6 +227,34 @@ JCClassDecl classDef(ClassSymbol c) {
227227 return def ;
228228 }
229229
230+ /**
231+ * Get the enum constants for the given enum class symbol, if known.
232+ * They will only be found if they are defined within the same top-level
233+ * class as the class being compiled, so it's safe to assume that they
234+ * can't change at runtime due to a recompilation.
235+ */
236+ List <Name > enumNamesFor (ClassSymbol c ) {
237+
238+ // Find the class definition and verify it is an enum class
239+ final JCClassDecl classDef = classDef (c );
240+ if (classDef == null ||
241+ (classDef .mods .flags & ENUM ) == 0 ||
242+ (types .supertype (currentClass .type ).tsym .flags () & ENUM ) != 0 ) {
243+ return null ;
244+ }
245+
246+ // Gather the enum identifiers
247+ ListBuffer <Name > idents = new ListBuffer <>();
248+ for (List <JCTree > defs = classDef .defs ; defs .nonEmpty (); defs =defs .tail ) {
249+ if (defs .head .hasTag (VARDEF ) &&
250+ (((JCVariableDecl ) defs .head ).mods .flags & ENUM ) != 0 ) {
251+ JCVariableDecl var = (JCVariableDecl )defs .head ;
252+ idents .append (var .name );
253+ }
254+ }
255+ return idents .toList ();
256+ }
257+
230258 /** A hash table mapping class symbols to lists of free variables.
231259 * accessed by them. Only free variables of the method immediately containing
232260 * a class are associated with that class.
@@ -427,14 +455,62 @@ List<VarSymbol> freevars(ClassSymbol c) {
427455 Map <TypeSymbol ,EnumMapping > enumSwitchMap = new LinkedHashMap <>();
428456
429457 EnumMapping mapForEnum (DiagnosticPosition pos , TypeSymbol enumClass ) {
430- EnumMapping map = enumSwitchMap .get (enumClass );
431- if (map == null )
432- enumSwitchMap .put (enumClass , map = new EnumMapping (pos , enumClass ));
433- return map ;
458+
459+ // If enum class is part of this compilation, just switch on ordinal value
460+ if (enumClass .kind == TYP ) {
461+ final List <Name > idents = enumNamesFor ((ClassSymbol )enumClass );
462+ if (idents != null )
463+ return new CompileTimeEnumMapping (idents );
464+ }
465+
466+ // Map identifiers to ordinal values at runtime, and then switch on that
467+ return enumSwitchMap .computeIfAbsent (enumClass , ec -> new RuntimeEnumMapping (pos , ec ));
434468 }
435469
436- /** This map gives a translation table to be used for enum
437- * switches.
470+ /** Generates a test value and corresponding cases for a switch on an enum type.
471+ */
472+ interface EnumMapping {
473+
474+ /** Given an expression for the enum value's ordinal, generate an expression for the switch statement.
475+ */
476+ JCExpression switchValue (JCExpression ordinalExpr );
477+
478+ /** Generate the switch statement case value corresponding to the given enum value.
479+ */
480+ JCLiteral caseValue (VarSymbol v );
481+
482+ default void translate () {
483+ }
484+ }
485+
486+ /** EnumMapping using compile-time constants. Only valid when compiling the enum class itself,
487+ * because otherwise the ordinals we use could become obsolete if/when the enum class is recompiled.
488+ */
489+ class CompileTimeEnumMapping implements EnumMapping {
490+
491+ final List <Name > enumNames ;
492+
493+ CompileTimeEnumMapping (List <Name > enumNames ) {
494+ Assert .check (enumNames != null );
495+ this .enumNames = enumNames ;
496+ }
497+
498+ @ Override
499+ public JCExpression switchValue (JCExpression ordinalExpr ) {
500+ return ordinalExpr ;
501+ }
502+
503+ @ Override
504+ public JCLiteral caseValue (VarSymbol v ) {
505+ final int ordinal = enumNames .indexOf (v .name );
506+ Assert .check (ordinal != -1 );
507+ return make .Literal (ordinal );
508+ }
509+ }
510+
511+ /** EnumMapping using run-time ordinal lookup.
512+ *
513+ * This builds a translation table to be used for enum switches.
438514 *
439515 * <p>For each enum that appears as the type of a switch
440516 * expression, we maintain an EnumMapping to assist in the
@@ -466,8 +542,8 @@ EnumMapping mapForEnum(DiagnosticPosition pos, TypeSymbol enumClass) {
466542 * </pre>
467543 * class EnumMapping provides mapping data and support methods for this translation.
468544 */
469- class EnumMapping {
470- EnumMapping (DiagnosticPosition pos , TypeSymbol forEnum ) {
545+ class RuntimeEnumMapping implements EnumMapping {
546+ RuntimeEnumMapping (DiagnosticPosition pos , TypeSymbol forEnum ) {
471547 this .forEnum = forEnum ;
472548 this .values = new LinkedHashMap <>();
473549 this .pos = pos ;
@@ -500,15 +576,22 @@ class EnumMapping {
500576 // the mapped values
501577 final Map <VarSymbol ,Integer > values ;
502578
503- JCLiteral forConstant (VarSymbol v ) {
579+ @ Override
580+ public JCExpression switchValue (JCExpression ordinalExpr ) {
581+ return make .Indexed (mapVar , ordinalExpr );
582+ }
583+
584+ @ Override
585+ public JCLiteral caseValue (VarSymbol v ) {
504586 Integer result = values .get (v );
505587 if (result == null )
506588 values .put (v , result = next ++);
507589 return make .Literal (result );
508590 }
509591
510592 // generate the field initializer for the map
511- void translate () {
593+ @ Override
594+ public void translate () {
512595 boolean prevAllowProtectedAccess = attrEnv .info .allowProtectedAccess ;
513596 try {
514597 make .at (pos .getStartPosition ());
@@ -3760,7 +3843,7 @@ public JCTree visitEnumSwitch(JCTree tree, JCExpression selector, List<JCCase> c
37603843 selector .type ,
37613844 currentMethodSym );
37623845 JCStatement var = make .at (tree .pos ()).VarDef (dollar_s , selector ).setType (dollar_s .type );
3763- newSelector = make . Indexed ( map .mapVar ,
3846+ newSelector = map .switchValue (
37643847 make .App (make .Select (make .Ident (dollar_s ),
37653848 ordinalMethod )));
37663849 newSelector =
@@ -3771,7 +3854,7 @@ public JCTree visitEnumSwitch(JCTree tree, JCExpression selector, List<JCCase> c
37713854 .setType (newSelector .type ))
37723855 .setType (newSelector .type );
37733856 } else {
3774- newSelector = make . Indexed ( map .mapVar ,
3857+ newSelector = map .switchValue (
37753858 make .App (make .Select (selector ,
37763859 ordinalMethod )));
37773860 }
@@ -3783,7 +3866,7 @@ public JCTree visitEnumSwitch(JCTree tree, JCExpression selector, List<JCCase> c
37833866 pat = makeLit (syms .intType , -1 );
37843867 } else {
37853868 VarSymbol label = (VarSymbol )TreeInfo .symbol (((JCConstantCaseLabel ) c .labels .head ).expr );
3786- pat = map .forConstant (label );
3869+ pat = map .caseValue (label );
37873870 }
37883871 newCases .append (make .Case (JCCase .STATEMENT , List .of (make .ConstantCaseLabel (pat )), c .stats , null ));
37893872 } else {
0 commit comments