Skip to content

Commit

Permalink
Merge pull request #376 from kusumotolab/fix-implementation-in-GA
Browse files Browse the repository at this point in the history
GAの挙動を変更
  • Loading branch information
a3636tako authored Oct 31, 2018
2 parents 24fa001 + 809ac45 commit bedfbb9
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 74 deletions.
23 changes: 15 additions & 8 deletions src/main/java/jp/kusumotolab/kgenprog/ga/RandomMutation.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ public List<Variant> exec(final VariantStore variantStore) {

final List<Variant> generatedVariants = new ArrayList<>();

for (final Variant variant : variantStore.getCurrentVariants()) {
final List<Variant> currentVariants = variantStore.getCurrentVariants();

final Roulette<Variant> variantRoulette = new Roulette<>(currentVariants, e -> {
final Fitness fitness = e.getFitness();
final double value = fitness.getValue();
return Double.isNaN(value) ? 0 : value + 1;
}, random);

for (int i = 0; i < numberOfBase; i++) {
final Variant variant = variantRoulette.exec();
final List<Suspiciousness> suspiciousnesses = variant.getSuspiciousnesses();
final Function<Suspiciousness, Double> weightFunction = susp -> Math.pow(susp.getValue(), 2);

Expand All @@ -40,13 +49,11 @@ public List<Variant> exec(final VariantStore variantStore) {
final Roulette<Suspiciousness> roulette =
new Roulette<>(suspiciousnesses, weightFunction, random);

for (int i = 0; i < numberOfBase; i++) {
final Suspiciousness suspiciousness = roulette.exec();
final Base base = makeBase(suspiciousness);
final Gene gene = makeGene(variant.getGene(), base);
final HistoricalElement element = new MutationHistoricalElement(variant, base);
generatedVariants.add(variantStore.createVariant(gene, element));
}
final Suspiciousness suspiciousness = roulette.exec();
final Base base = makeBase(suspiciousness);
final Gene gene = makeGene(variant.getGene(), base);
final HistoricalElement element = new MutationHistoricalElement(variant, base);
generatedVariants.add(variantStore.createVariant(gene, element));
}

log.debug("exit exec(VariantStore)");
Expand Down
25 changes: 16 additions & 9 deletions src/main/java/jp/kusumotolab/kgenprog/ga/SinglePointCrossover.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -16,11 +14,11 @@ public class SinglePointCrossover implements Crossover {
private static Logger log = LoggerFactory.getLogger(SinglePointCrossover.class);

private final Random random;
private final int numberOfPair;
private final int crossoverGeneratingCount;

public SinglePointCrossover(final Random random, final int numberOfPair) {
public SinglePointCrossover(final Random random, final int crossoverGeneratingCount) {
this.random = random;
this.numberOfPair = numberOfPair;
this.crossoverGeneratingCount = crossoverGeneratingCount;
}

public SinglePointCrossover(final Random random) {
Expand All @@ -42,10 +40,19 @@ public List<Variant> exec(final VariantStore variantStore) {
return Collections.emptyList();
}

return IntStream.range(0, numberOfPair)
.mapToObj(e -> makeVariants(filteredVariants, variantStore))
.flatMap(Collection::stream)
.collect(Collectors.toList());
final List<Variant> variants = new ArrayList<>();

for (int i = 0; i < crossoverGeneratingCount/2; i++) {
final List<Variant> newVariants = makeVariants(filteredVariants, variantStore);
variants.addAll(newVariants);
}
if (crossoverGeneratingCount != 0 && crossoverGeneratingCount % 2 != 0) {
final List<Variant> newVariants = makeVariants(filteredVariants, variantStore);
if (!newVariants.isEmpty()) {
variants.add(newVariants.get(0));
}
}
return variants;
}

private List<Variant> makeVariants(final List<Variant> variants, final VariantStore store) {
Expand Down
186 changes: 135 additions & 51 deletions src/test/java/jp/kusumotolab/kgenprog/ga/RandomMutationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -25,6 +24,7 @@
import jp.kusumotolab.kgenprog.project.GeneratedSourceCode;
import jp.kusumotolab.kgenprog.project.Operation;
import jp.kusumotolab.kgenprog.project.ProductSourcePath;
import jp.kusumotolab.kgenprog.project.SourcePath;
import jp.kusumotolab.kgenprog.project.factory.TargetProject;
import jp.kusumotolab.kgenprog.project.factory.TargetProjectFactory;
import jp.kusumotolab.kgenprog.project.jdt.GeneratedJDTAST;
Expand All @@ -37,6 +37,11 @@ public class RandomMutationTest {
@SuppressWarnings("serial")
private class MockRandom extends Random {

MockRandom(final long seed) {
super();
setSeed(seed);
}

@Override
public int nextInt() {
return 0;
Expand All @@ -47,76 +52,73 @@ public int nextInt(final int divisor) {
return 1;
}

@Override
public double nextDouble() {
return 1.0d;
}

@Override
public boolean nextBoolean() {
return true;
}
}

@Test
public void testExec() throws NoSuchFieldException, IllegalAccessException {
final Path basePath = Paths.get("example/BuildSuccess01");
final TargetProject targetProject = TargetProjectFactory.create(basePath);
final GeneratedSourceCode sourceCode = TestUtil.createGeneratedSourceCode(targetProject);
final Random random = new MockRandom();
random.setSeed(0);
final CandidateSelection statementSelection = new RouletteStatementSelection(random);
final RandomMutation randomMutation = new RandomMutation(15, random, statementSelection);
randomMutation.setCandidates(sourceCode.getProductAsts());
public void testGeneratedVariantsSize() {
final GeneratedSourceCode generatedSourceCode = createGeneratedSourceCode();

final GeneratedAST<ProductSourcePath> generatedAST =
new ArrayList<>(sourceCode.getProductAsts()).get(0);
final ProductSourcePath sourcePath = generatedAST.getSourcePath();
final CompilationUnit root =
(CompilationUnit) ((GeneratedJDTAST<ProductSourcePath>) generatedAST).getRoot()
.getRoot()
.getRoot();
final TypeDeclaration typeRoot = (TypeDeclaration) root.types()
.get(0);

@SuppressWarnings("unchecked")
final List<Statement> statements = typeRoot.getMethods()[0].getBody()
.statements();
final Variant initialVariant = createInitialVariant(generatedSourceCode);
final VariantStore variantStore = createVariantStore(initialVariant);

final double[] value = {0.8};
final List<Suspiciousness> suspiciousnesses = statements.stream()
.map(e -> new JDTASTLocation(sourcePath, e))
.map(e -> {
value[0] += 0.1;
return new Suspiciousness(e, value[0]);
})
.collect(Collectors.toList());
final RandomMutation randomMutation = createRandomMutation(generatedSourceCode);
final List<Variant> variantList = randomMutation.exec(variantStore);

final Gene initialGene = new Gene(Collections.emptyList());
final Variant initialVariant =
new Variant(0, initialGene, null, null, null, suspiciousnesses, null);
final VariantStore variantStore = mock(VariantStore.class);
when(variantStore.getCurrentVariants()).thenReturn(Arrays.asList(initialVariant));
when(variantStore.createVariant(any(), any())).then(ans -> {
return new Variant(0, ans.getArgument(0), null, null, null, null, ans.getArgument(1));
});
assertThat(variantList).hasSize(15);
}

// 正しく15個のVariantが生成されるかのテスト
@Test
public void testBias() {
final GeneratedSourceCode generatedSourceCode = createGeneratedSourceCode();

final Variant initialVariant = createInitialVariant(generatedSourceCode);
final List<Suspiciousness> suspiciousnesses = initialVariant.getSuspiciousnesses();

final VariantStore variantStore = createVariantStore(initialVariant);
final RandomMutation randomMutation = createRandomMutation(generatedSourceCode, new Random(0));
final List<Variant> variantList = randomMutation.exec(variantStore);
assertThat(variantList).hasSize(15);

// Suspiciousnessが高い場所ほど多くの操作が生成されているかのテスト
final Map<String, List<Base>> map = variantList.stream()
.map(this::getLastBase)
.collect(
Collectors.groupingBy(e -> ((JDTASTLocation) e.getTargetLocation()).node.toString()));

final String weakSuspiciousness = ((JDTASTLocation) suspiciousnesses.get(0)
.getLocation()).node.toString();
final String strongSuspiciousness = ((JDTASTLocation) suspiciousnesses.get(1)
.getLocation()).node.toString();

assertThat(map.get(weakSuspiciousness)
.size()).isLessThan(map.get(strongSuspiciousness)
.size());
final List<Base> weakBases = map.get(weakSuspiciousness);
final List<Base> strongBases = map.get(strongSuspiciousness);

assertThat(weakBases.size())
.isLessThan(strongBases.size());
}

@Test
public void testGeneratedOperation() throws NoSuchFieldException, IllegalAccessException {
final GeneratedSourceCode generatedSourceCode = createGeneratedSourceCode();

final Variant initialVariant = createInitialVariant(generatedSourceCode);
final VariantStore variantStore = createVariantStore(initialVariant);

final RandomMutation randomMutation = createRandomMutation(generatedSourceCode);
final List<Variant> variantList = randomMutation.exec(variantStore);

// TestNumberGenerationにしたがってOperationが生成されているかのテスト
final Variant variant = variantList.get(0);
final Base base = getLastBase(variant);
final Gene gene = variant.getGene();
final List<Base> bases = gene.getBases();
final Base base = bases.get(0);

final JDTASTLocation targetLocation = (JDTASTLocation) base.getTargetLocation();
assertThat(targetLocation.node).isSameSourceCodeAs("return n;");

Expand All @@ -128,15 +130,97 @@ public void testExec() throws NoSuchFieldException, IllegalAccessException {
.getDeclaredField("astNode");
field.setAccessible(true);
final ASTNode node = (ASTNode) field.get(insertOperation);
assertThat(node).isSameSourceCodeAs("n--;");
assertThat(node).isSameSourceCodeAs("return n;");

}

@Test
public void testHistoricalElement() {
final GeneratedSourceCode generatedSourceCode = createGeneratedSourceCode();

final Variant initialVariant = createInitialVariant(generatedSourceCode);
final VariantStore variantStore = createVariantStore(initialVariant);

final RandomMutation randomMutation = createRandomMutation(generatedSourceCode);
final List<Variant> variantList = randomMutation.exec(variantStore);

final Variant variant = variantList.get(0);
final Base base = getLastBase(variant);

// HistoricalELementのテスト
final HistoricalElement element = variant.getHistoricalElement();
final List<Variant> parents = element.getParents();
assertThat(element).isInstanceOf(MutationHistoricalElement.class);

final MutationHistoricalElement mElement = (MutationHistoricalElement) element;
assertThat(element.getParents()).hasSize(1)
final Base appendedBase = mElement.getAppendedBase();
assertThat(parents).hasSize(1)
.containsExactly(initialVariant);
assertThat(mElement.getAppendedBase()).isEqualTo(base);

assertThat(appendedBase).isEqualTo(base);
}


private GeneratedSourceCode createGeneratedSourceCode() {
final Path basePath = Paths.get("example/BuildSuccess01");
final TargetProject targetProject = TargetProjectFactory.create(basePath);
return TestUtil.createGeneratedSourceCode(targetProject);
}

private RandomMutation createRandomMutation(final GeneratedSourceCode sourceCode) {
final Random random = new MockRandom(0);
return createRandomMutation(sourceCode, random);
}

private RandomMutation createRandomMutation(final GeneratedSourceCode sourceCode,
final Random random) {
final CandidateSelection statementSelection = new RouletteStatementSelection(random);
final RandomMutation randomMutation = new RandomMutation(15, random, statementSelection);
randomMutation.setCandidates(sourceCode.getProductAsts());
return randomMutation;
}

private GeneratedAST<ProductSourcePath> createGeneratedAST(final GeneratedSourceCode sourceCode) {
final List<GeneratedAST<ProductSourcePath>> asts = sourceCode.getProductAsts();
return asts.get(0);
}

@SuppressWarnings("unchecked")
private List<Statement> createStatement(final GeneratedAST<ProductSourcePath> generatedAST) {
final CompilationUnit root = ((GeneratedJDTAST<ProductSourcePath>) generatedAST).getRoot();
final List<TypeDeclaration> types = root.types();
final TypeDeclaration typeRoot = types.get(0);

return (List<Statement>) typeRoot.getMethods()[0].getBody()
.statements();
}

private Variant createInitialVariant(final GeneratedSourceCode sourceCode) {
final GeneratedAST<ProductSourcePath> generatedAST = createGeneratedAST(sourceCode);

final List<Statement> statements = createStatement(generatedAST);
final SourcePath sourcePath = generatedAST.getSourcePath();

final List<Suspiciousness> suspiciousnesses = new ArrayList<>();
double susValue = 0.0;
for (final Statement statement : statements) {
susValue += 1.0 / statements.size();
final JDTASTLocation location = new JDTASTLocation(sourcePath, statement);
final Suspiciousness suspiciousness = new Suspiciousness(location, susValue);
suspiciousnesses.add(suspiciousness);
}

final Gene initialGene = new Gene(Collections.emptyList());
return new Variant(0, initialGene, sourceCode, null, new SimpleFitness(0.0), suspiciousnesses,
null);
}

private VariantStore createVariantStore(final Variant initialVariant) {
final VariantStore variantStore = mock(VariantStore.class);
when(variantStore.getCurrentVariants()).thenReturn(Collections.singletonList(initialVariant));
when(variantStore.createVariant(any(), any())).then(ans -> {
return new Variant(0, ans.getArgument(0), null, null, null, null, ans.getArgument(1));
});
return variantStore;
}

private Base getLastBase(final Variant variant) {
Expand Down
Loading

0 comments on commit bedfbb9

Please sign in to comment.