Skip to content

Commit

Permalink
Rework GoStructInitializationInspection
Browse files Browse the repository at this point in the history
fixes #2819
  • Loading branch information
wbars committed Dec 13, 2016
1 parent eb669fc commit 11bf872
Show file tree
Hide file tree
Showing 47 changed files with 602 additions and 71 deletions.
13 changes: 3 additions & 10 deletions src/com/goide/completion/GoStructLiteralCompletion.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import com.goide.psi.*;
import com.goide.psi.impl.GoPsiImplUtil;
import com.intellij.psi.PsiElement;
import com.intellij.util.ObjectUtils;
import com.intellij.util.containers.ContainerUtil;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -61,8 +59,8 @@ enum Variants {

@NotNull
static Variants allowedVariants(@Nullable GoReferenceExpression structFieldReference) {
GoValue value = parent(structFieldReference, GoValue.class);
GoElement element = parent(value, GoElement.class);
GoValue value = GoPsiTreeUtil.getDirectParentOfType(structFieldReference, GoValue.class);
GoElement element = GoPsiTreeUtil.getDirectParentOfType(value, GoElement.class);
if (element != null && element.getKey() != null) {
return Variants.NONE;
}
Expand All @@ -75,7 +73,7 @@ static Variants allowedVariants(@Nullable GoReferenceExpression structFieldRefer
boolean hasValueInitializers = false;
boolean hasFieldValueInitializers = false;

GoLiteralValue literalValue = parent(element, GoLiteralValue.class);
GoLiteralValue literalValue = GoPsiTreeUtil.getDirectParentOfType(element, GoLiteralValue.class);
List<GoElement> fieldInitializers = literalValue != null ? literalValue.getElementList() : Collections.emptyList();
for (GoElement initializer : fieldInitializers) {
if (initializer == element) {
Expand Down Expand Up @@ -105,9 +103,4 @@ static Set<String> alreadyAssignedFields(@Nullable GoLiteralValue literal) {
return identifier != null ? identifier.getText() : null;
});
}

@Contract("null,_->null")
private static <T> T parent(@Nullable PsiElement of, @NotNull Class<T> parentClass) {
return ObjectUtils.tryCast(of != null ? of.getParent() : null, parentClass);
}
}
147 changes: 102 additions & 45 deletions src/com/goide/inspections/GoStructInitializationInspection.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,32 @@
import com.goide.util.GoUtil;
import com.intellij.codeInspection.*;
import com.intellij.codeInspection.ui.SingleCheckboxOptionsPanel;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Comparing;
import com.intellij.openapi.util.InvalidDataException;
import com.intellij.openapi.util.WriteExternalException;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.ObjectUtils;
import org.jdom.Element;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.swing.*;
import java.util.List;

import static com.intellij.util.containers.ContainerUtil.*;
import static java.lang.Math.min;
import static java.util.stream.Collectors.toList;
import static java.util.stream.IntStream.range;

public class GoStructInitializationInspection extends GoInspectionBase {
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct field";
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct fields";
private static final GoReplaceWithNamedStructFieldQuickFix QUICK_FIX = new GoReplaceWithNamedStructFieldQuickFix();
public boolean reportLocalStructs;
/**
* @deprecated use reportLocalStructs
* @deprecated use {@link #reportLocalStructs}
*/
@SuppressWarnings("WeakerAccess") public Boolean reportImportedStructs;

Expand All @@ -49,67 +56,117 @@ public class GoStructInitializationInspection extends GoInspectionBase {
protected GoVisitor buildGoVisitor(@NotNull ProblemsHolder holder, @NotNull LocalInspectionToolSession session) {
return new GoVisitor() {
@Override
public void visitLiteralValue(@NotNull GoLiteralValue o) {
if (PsiTreeUtil.getParentOfType(o, GoReturnStatement.class, GoShortVarDeclaration.class, GoAssignmentStatement.class) == null) {
return;
}
PsiElement parent = o.getParent();
GoType refType = GoPsiImplUtil.getLiteralType(parent, false);
if (refType instanceof GoStructType) {
processStructType(holder, o, (GoStructType)refType);
public void visitLiteralValue(@NotNull GoLiteralValue literalValue) {
GoStructType structType = getLiteralStructType(literalValue);
if (structType == null || !isStructImportedOrLocalAllowed(structType, literalValue)) return;

List<GoElement> elements = literalValue.getElementList();
List<GoNamedElement> definitions = getFieldDefinitions(structType);

if (!areElementsKeysMatchesDefinitions(elements, definitions)) return;
registerProblemsForElementsWithoutKeys(elements, definitions.size());
}

private void registerProblemsForElementsWithoutKeys(@NotNull List<GoElement> elements, int definitionsCount) {
for (int i = 0; i < min(elements.size(), definitionsCount); i++) {
if (elements.get(i).getKey() != null) continue;
holder.registerProblem(elements.get(i), "Unnamed field initialization", ProblemHighlightType.WEAK_WARNING, QUICK_FIX);
}
}
};
}

@Override
public JComponent createOptionsPanel() {
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
}
@Contract("null -> null")
private static GoStructType getLiteralStructType(@Nullable GoLiteralValue literalValue) {
GoCompositeLit parentLit = GoPsiTreeUtil.getDirectParentOfType(literalValue, GoCompositeLit.class);
if (parentLit != null && !isStructLit(parentLit)) return null;

private void processStructType(@NotNull ProblemsHolder holder, @NotNull GoLiteralValue element, @NotNull GoStructType structType) {
if (reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), element.getContainingFile())) {
processLiteralValue(holder, element, structType.getFieldDeclarationList());
}
GoStructType litType = ObjectUtils.tryCast(GoPsiImplUtil.getLiteralType(literalValue, parentLit == null), GoStructType.class);
GoNamedElement definition = getFieldDefinition(GoPsiTreeUtil.getDirectParentOfType(literalValue, GoValue.class));
return definition != null && litType != null ? getUnderlyingStructType(definition.getGoType(null)) : litType;
}

private static void processLiteralValue(@NotNull ProblemsHolder holder,
@NotNull GoLiteralValue o,
@NotNull List<GoFieldDeclaration> fields) {
List<GoElement> vals = o.getElementList();
for (int elemId = 0; elemId < vals.size(); elemId++) {
ProgressManager.checkCanceled();
GoElement element = vals.get(elemId);
if (element.getKey() == null && elemId < fields.size()) {
String structFieldName = getFieldName(fields.get(elemId));
LocalQuickFix[] fixes = structFieldName != null ? new LocalQuickFix[]{new GoReplaceWithNamedStructFieldQuickFix(structFieldName)}
: LocalQuickFix.EMPTY_ARRAY;
holder.registerProblem(element, "Unnamed field initialization", ProblemHighlightType.GENERIC_ERROR_OR_WARNING, fixes);
}
}
@Nullable
private static GoNamedElement getFieldDefinition(@Nullable GoValue value) {
GoKey key = PsiTreeUtil.getPrevSiblingOfType(value, GoKey.class);
GoFieldName fieldName = key != null ? key.getFieldName() : null;
PsiElement field = fieldName != null ? fieldName.resolve() : null;
return field instanceof GoAnonymousFieldDefinition || field instanceof GoFieldDefinition ? ObjectUtils
.tryCast(field, GoNamedElement.class) : null;
}

@Nullable
private static String getFieldName(@NotNull GoFieldDeclaration declaration) {
List<GoFieldDefinition> list = declaration.getFieldDefinitionList();
GoFieldDefinition fieldDefinition = ContainerUtil.getFirstItem(list);
return fieldDefinition != null ? fieldDefinition.getIdentifier().getText() : null;
@Contract("null -> null")
private static GoStructType getUnderlyingStructType(@Nullable GoType type) {
return type != null ? ObjectUtils.tryCast(type.getUnderlyingType(), GoStructType.class) : null;
}

private static boolean isStructLit(@NotNull GoCompositeLit parentLit) {
return getUnderlyingStructType(parentLit.getGoType(null)) != null;
}

private boolean isStructImportedOrLocalAllowed(@NotNull GoStructType structType, @NotNull GoLiteralValue literalValue) {
return reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), literalValue.getContainingFile());
}

private static boolean areElementsKeysMatchesDefinitions(@NotNull List<GoElement> elements,
@NotNull List<GoNamedElement> fieldDefinitions) {
return range(0, elements.size())
.allMatch(i -> isNullOrNamesEqual(elements.get(i).getKey(), GoPsiImplUtil.getByIndex(fieldDefinitions, i)));
}


@Contract("null, _ -> true")
private static boolean isNullOrNamesEqual(@Nullable GoKey key, @Nullable GoNamedElement elementToCompare) {
return key == null || elementToCompare != null && Comparing.equal(key.getText(), elementToCompare.getName());
}

@NotNull
private static List<GoNamedElement> getFieldDefinitions(@Nullable GoStructType type) {
return type != null ? type.getFieldDeclarationList().stream()
.flatMap(declaration -> getFieldDefinitions(declaration).stream())
.collect(toList()) : emptyList();
}

@NotNull
private static List<GoNamedElement> getFieldDefinitions(@NotNull GoFieldDeclaration declaration) {
GoNamedElement anonymousDefinition = ObjectUtils.tryCast(declaration.getAnonymousFieldDefinition(), GoNamedElement.class);
return anonymousDefinition != null
? list(anonymousDefinition)
: map(declaration.getFieldDefinitionList(), definition -> ObjectUtils.tryCast(definition, GoNamedElement.class));
}

@Override
public JComponent createOptionsPanel() {
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
}

private static class GoReplaceWithNamedStructFieldQuickFix extends LocalQuickFixBase {
private String myStructField;

public GoReplaceWithNamedStructFieldQuickFix(@NotNull String structField) {
public GoReplaceWithNamedStructFieldQuickFix() {
super(REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME);
myStructField = structField;
}

@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
PsiElement startElement = descriptor.getStartElement();
if (startElement instanceof GoElement) {
startElement.replace(GoElementFactory.createLiteralValueElement(project, myStructField, startElement.getText()));
}
PsiElement element = ObjectUtils.tryCast(descriptor.getStartElement(), GoElement.class);
GoLiteralValue literal = element != null && element.isValid() ? PsiTreeUtil.getParentOfType(element, GoLiteralValue.class) : null;

List<GoElement> elements = literal != null ? literal.getElementList() : emptyList();
List<GoNamedElement> fieldDefinitionNames = getFieldDefinitions(getLiteralStructType(literal));
if (!areElementsKeysMatchesDefinitions(elements, fieldDefinitionNames)) return;
addKeysToElements(project, elements, fieldDefinitionNames);
}
}

private static void addKeysToElements(@NotNull Project project,
@NotNull List<GoElement> elements,
@NotNull List<GoNamedElement> fieldDefinitions) {
for (int i = 0; i < min(elements.size(), fieldDefinitions.size()); i++) {
GoElement element = elements.get(i);
String fieldDefinitionName = fieldDefinitions.get(i).getName();
GoValue value = fieldDefinitionName != null && element.getKey() == null ? element.getValue() : null;
if (value != null) element.replace(GoElementFactory.createLiteralValueElement(project, fieldDefinitionName, value.getText()));
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/com/goide/psi/GoPsiTreeUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import com.intellij.psi.stubs.StubElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.PsiUtilCore;
import com.intellij.util.ObjectUtils;
import com.intellij.util.SmartList;
import com.intellij.util.containers.ContainerUtil;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -155,5 +157,10 @@ private static PsiElement findNotWhiteSpaceElementAtOffset(@NotNull GoFile file,
}
return element;
}

@Contract("null,_->null")
public static <T> T getDirectParentOfType(@Nullable PsiElement element, @NotNull Class<T> aClass) {
return element != null ? ObjectUtils.tryCast(element.getParent(), aClass) : null;
}
}

2 changes: 1 addition & 1 deletion src/com/goide/psi/impl/GoElementFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ public static GoType createType(@NotNull Project project, @NotNull String text)
return PsiTreeUtil.findChildOfType(file, GoType.class);
}

public static PsiElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
public static GoElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
GoFile file = createFileFromText(project, "package a; var _ = struct { a string } { " + key + ": " + value + " }");
return PsiTreeUtil.findChildOfType(file, GoElement.class);
}
Expand Down
6 changes: 4 additions & 2 deletions src/com/goide/psi/impl/GoPsiImplUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,12 @@ public static GoType getLiteralType(@Nullable PsiElement context, boolean consid
@Nullable
public static GoValue getParentGoValue(@NotNull PsiElement element) {
PsiElement place = element;
while ((place = PsiTreeUtil.getParentOfType(place, GoLiteralValue.class)) != null) {
do {
if (place.getParent() instanceof GoValue) {
return (GoValue)place.getParent();
}
}
while ((place = PsiTreeUtil.getParentOfType(place, GoLiteralValue.class)) != null);
return null;
}

Expand Down Expand Up @@ -1468,7 +1469,8 @@ public static GoExpression getValue(@NotNull GoConstDefinition definition) {
return getByIndex(((GoConstSpec)parent).getExpressionList(), index);
}

private static <T> T getByIndex(@NotNull List<T> list, int index) {
@Nullable
public static <T> T getByIndex(@NotNull List<T> list, int index) {
return 0 <= index && index < list.size() ? list.get(index) : null;
}

Expand Down
11 changes: 11 additions & 0 deletions testData/inspections/struct-initialization/anonField-after.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package foo

type S struct {
X string
string
Y int
}
func main() {
var s S
s = S{X: "X", string: "a", Y: 1}
}
11 changes: 11 additions & 0 deletions testData/inspections/struct-initialization/anonField.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package foo

type S struct {
X string
string
Y int
}
func main() {
var s S
s = S{<caret><weak_warning descr="Unnamed field initialization">"X"</weak_warning>, <weak_warning descr="Unnamed field initialization">"a"</weak_warning>, Y: 1}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{X: 1, Y: 0, 2}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{<weak_warning descr="Unnamed field initialization"><caret>1</weak_warning>, <weak_warning descr="Unnamed field initialization">0</weak_warning>, 2}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo

type S struct {
X, Y int
}
func main() {
s := S{<caret>1, 0, X: 2}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo

type S struct {
t int
}

func main() {
var _ = []S{ {t: 1} }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package foo

type S struct {
t int
}

func main() {
var _ = []S{ {<weak_warning descr="Unnamed field initialization"><caret>1</weak_warning>} }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package foo

func main() {
type B struct {
Y int
}

type S struct {
X int
B
Z int
}

s := S{X: 1, B: B{Y: 2}, Z: 3}
print(s.B.Y)
}
16 changes: 16 additions & 0 deletions testData/inspections/struct-initialization/innerAnonStruct.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package foo

func main() {
type B struct {
Y int
}

type S struct {
X int
B
Z int
}

s := S{<weak_warning descr="Unnamed field initialization">1<caret></weak_warning>, <weak_warning descr="Unnamed field initialization">B{Y: 2}</weak_warning>, <weak_warning descr="Unnamed field initialization">3</weak_warning>}
print(s.B.Y)
}
Loading

0 comments on commit 11bf872

Please sign in to comment.