Skip to content

Commit

Permalink
fix(ai): Stop turn when AI throws exception
Browse files Browse the repository at this point in the history
  • Loading branch information
vincent4vx committed May 20, 2024
1 parent 1a29441 commit 2dbb4c6
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void dispatch(Object event) {
for (Listener listener : container) {
try {
listener.on(event);
} catch (RuntimeException e) {
} catch (Exception e) {
logger.error("Error during execution of listener " + listener.getClass().getName(), e);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/fr/quatrevieux/araknemu/game/GameModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ private void configureServices(ContainerConfigurator configurator) {
StatesModule::new,
RaulebaqueModule::new,
LaunchedSpellsModule::new,
fight -> new AiModule(container.get(AiFactory.class)),
fight -> new AiModule(container.get(AiFactory.class), fight, container.get(Logger.class)),
fight -> new MonsterInvocationModule(container.get(MonsterService.class), container.get(FighterFactory.class), fight),
SpiritualLeashModule::new,
CarryingModule::new,
Expand Down
37 changes: 21 additions & 16 deletions src/main/java/fr/quatrevieux/araknemu/game/fight/ai/FighterAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,34 @@ public void run() {
}

final Turn<FightAction> currentTurn = turn;
boolean stop = true;

if (!currentTurn.active()) {
turn = null;
return;
}

final Optional<FightAction> action = generator.generate(
this,
new FightAiActionFactoryAdapter(
fighter,
fight,
fight.actions()
)
);

if (action.isPresent()) {
currentTurn.perform(action.get());
currentTurn.later(() -> fight.schedule(this, Duration.ofMillis(800)));
return;
try {
final Optional<FightAction> action = generator.generate(
this,
new FightAiActionFactoryAdapter(
fighter,
fight,
fight.actions()
)
);

if (action.isPresent()) {
currentTurn.perform(action.get());
currentTurn.later(() -> fight.schedule(this, Duration.ofMillis(800)));
stop = false;
}
} finally {
if (stop) {
turn = null;
currentTurn.stop();
}
}

turn = null;
currentTurn.stop();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,28 @@
package fr.quatrevieux.araknemu.game.fight.module;

import fr.quatrevieux.araknemu.core.event.Listener;
import fr.quatrevieux.araknemu.game.fight.Fight;
import fr.quatrevieux.araknemu.game.fight.ai.FighterAI;
import fr.quatrevieux.araknemu.game.fight.ai.factory.AiFactory;
import fr.quatrevieux.araknemu.game.fight.fighter.Fighter;
import fr.quatrevieux.araknemu.game.fight.fighter.PlayableFighter;
import fr.quatrevieux.araknemu.game.fight.fighter.event.FighterInitialized;
import fr.quatrevieux.araknemu.game.fight.turn.FightTurn;
import fr.quatrevieux.araknemu.game.fight.turn.event.TurnStarted;
import org.apache.logging.log4j.Logger;

/**
* Fight module for enable AI
*/
public final class AiModule implements FightModule {
private final AiFactory<PlayableFighter> factory;
private final Fight fight;
private final Logger logger;

public AiModule(AiFactory<PlayableFighter> factory) {
public AiModule(AiFactory<PlayableFighter> factory, Fight fight, Logger logger) {
this.factory = factory;
this.fight = fight;
this.logger = logger;
}

@Override
Expand Down Expand Up @@ -74,7 +80,10 @@ public Class<TurnStarted> event() {
* Initialize the AI for the fighter (if supported)
*/
private void init(PlayableFighter fighter) {
factory.create(fighter).ifPresent(fighter::attach);
factory.create(fighter).ifPresent(ai -> {
logger.debug("AI initialized for {}", fighter);
fighter.attach(ai);
});
}

/**
Expand All @@ -83,8 +92,18 @@ private void init(PlayableFighter fighter) {
private void start(FightTurn turn) {
final FighterAI ai = turn.fighter().attachment(FighterAI.class);

if (ai != null) {
if (ai == null) {
return;
}

logger.debug("Starting AI for {}", turn.fighter());

try {
ai.start(turn);
} catch (Exception e) {
logger.error("Error during AI execution. Stop the turn.", e);
// Should be done asynchronously because the turn is not totally started
fight.execute(turn::stop);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ public boolean start() {

fighter.play(this);
active.set(true);
fight.dispatch(new TurnStarted(this));
timer = fight.schedule(this::stop, duration);
fight.dispatch(new TurnStarted(this));

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import fr.quatrevieux.araknemu.game.player.spell.SpellBook;
import fr.quatrevieux.araknemu.game.spell.SpellService;
import fr.quatrevieux.araknemu.game.spell.effect.SpellEffect;
import org.apache.logging.log4j.Logger;
import org.mockito.Mockito;

import java.lang.reflect.Field;
import java.util.Map;
Expand Down Expand Up @@ -78,7 +80,7 @@ public void configureFight(Consumer<FightBuilder> configurator) {
configurator.accept(builder);

fight = builder.build(true);
fight.register(new AiModule(new ChainAiFactory()));
fight.register(new AiModule(new ChainAiFactory(), fight, Mockito.mock(Logger.class)));
fight.register(new CommonEffectsModule(fight));

fighter = player.fighter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,24 @@
import fr.quatrevieux.araknemu.game.fight.ai.action.logic.NullGenerator;
import fr.quatrevieux.araknemu.game.fight.ai.memory.MemoryKey;
import fr.quatrevieux.araknemu.game.fight.ai.util.AIHelper;
import fr.quatrevieux.araknemu.game.fight.fighter.ActiveFighter;
import fr.quatrevieux.araknemu.game.fight.fighter.Fighter;
import fr.quatrevieux.araknemu.game.fight.fighter.PlayableFighter;
import fr.quatrevieux.araknemu.game.fight.fighter.invocation.DoubleFighter;
import fr.quatrevieux.araknemu.game.fight.fighter.player.PlayerFighter;
import fr.quatrevieux.araknemu.game.fight.state.PlacementState;
import fr.quatrevieux.araknemu.game.fight.turn.FightTurn;
import fr.quatrevieux.araknemu.game.fight.turn.action.Action;
import fr.quatrevieux.araknemu.game.fight.turn.action.ActionResult;
import fr.quatrevieux.araknemu.game.fight.turn.action.ActionType;
import fr.quatrevieux.araknemu.game.fight.turn.action.FightAction;
import io.github.artsok.RepeatedIfExceptionsTest;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.time.Duration;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -195,7 +200,10 @@ void startUnit() throws InterruptedException {

FighterAI ai = new FighterAI(fighter, fight, new GeneratorAggregate(new ActionGenerator[] {generator1, generator2}));

Mockito.when(generator1.generate(Mockito.eq(ai), Mockito.any(AiActionFactory.class))).thenReturn(Optional.of(Mockito.mock(Action.class)));
FightAction action = Mockito.mock(FightAction.class);
Mockito.when(action.validate(Mockito.any())).thenReturn(true);
Mockito.when(action.start()).thenReturn(Mockito.mock(ActionResult.class));
Mockito.when(generator1.generate(Mockito.eq(ai), Mockito.any(AiActionFactory.class))).thenReturn(Optional.of(action));

ai.start(turn);

Expand All @@ -209,6 +217,25 @@ void startUnit() throws InterruptedException {
assertTrue(turn.active());
}

@RepeatedIfExceptionsTest
void startWithExceptionShouldStopTurn() throws InterruptedException {
ActionGenerator generator1 = Mockito.mock(ActionGenerator.class);
Mockito.when(generator1.generate(Mockito.any(), Mockito.any())).thenThrow(new RuntimeException("Test"));

fight.turnList().start();
FightTurn turn = fight.turnList().current().get();

FighterAI ai = new FighterAI(fighter, fight, new GeneratorAggregate(new ActionGenerator[] {generator1}));

ai.start(turn);

Mockito.verify(generator1).initialize(ai);

Mockito.verify(generator1).generate(Mockito.eq(ai), Mockito.any(AiActionFactory.class));

assertFalse(turn.active());
}

@RepeatedIfExceptionsTest
void startShouldCallMemoryRefresh() throws InterruptedException {
ActionGenerator generator = Mockito.mock(ActionGenerator.class);
Expand All @@ -222,7 +249,10 @@ void startShouldCallMemoryRefresh() throws InterruptedException {

ai.set(key, value);

Mockito.when(generator.generate(Mockito.eq(ai), Mockito.any(AiActionFactory.class))).thenReturn(Optional.of(Mockito.mock(Action.class)));
FightAction action = Mockito.mock(FightAction.class);
Mockito.when(action.validate(Mockito.any())).thenReturn(true);
Mockito.when(action.start()).thenReturn(Mockito.mock(ActionResult.class));
Mockito.when(generator.generate(Mockito.eq(ai), Mockito.any(AiActionFactory.class))).thenReturn(Optional.of(action));

ai.start(turn);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import fr.quatrevieux.araknemu.game.fight.ai.AI;
import fr.quatrevieux.araknemu.game.fight.fighter.Fighter;
import fr.quatrevieux.araknemu.game.fight.fighter.PlayableFighter;
import fr.quatrevieux.araknemu.game.fight.turn.FightTurn;
import fr.quatrevieux.araknemu.game.fight.turn.Turn;
import fr.quatrevieux.araknemu.game.fight.turn.action.Action;
import fr.quatrevieux.araknemu.game.fight.turn.action.ActionResult;
Expand Down Expand Up @@ -60,7 +62,37 @@ public boolean validate(Turn<?> turn) {

@Override
public ActionResult start() {
return null;
return new ActionResult() {
@Override
public int action() {
return 0;
}

@Override
public PlayableFighter performer() {
return (PlayableFighter) ai.fighter();
}

@Override
public Object[] arguments() {
return new Object[0];
}

@Override
public boolean success() {
return false;
}

@Override
public boolean secret() {
return false;
}

@Override
public void apply(FightTurn turn) {

}
};
}
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import fr.quatrevieux.araknemu.game.fight.turn.action.util.BaseCriticalityStrategy;
import fr.quatrevieux.araknemu.game.spell.Spell;
import fr.quatrevieux.araknemu.game.spell.SpellService;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand All @@ -53,7 +55,7 @@ public void setUp() throws Exception {

fight = createFight();
fighter = player.fighter();
fight.register(new AiModule(new ChainAiFactory()));
fight.register(new AiModule(new ChainAiFactory(), fight, Mockito.mock(Logger.class)));
simulator = container.get(Simulator.class);
ai = new FighterAI(fighter, fight, new NullGenerator());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@
import fr.quatrevieux.araknemu.network.game.out.game.AddSprites;
import fr.quatrevieux.araknemu.network.game.out.game.UpdateCells;
import fr.quatrevieux.araknemu.network.game.out.info.Error;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.sql.SQLException;
import java.util.ArrayList;
Expand Down Expand Up @@ -110,7 +112,7 @@ public void setUp() throws Exception {
fight.register(new IndirectSpellApplyEffectsModule(fight, container.get(SpellService.class)));
fight.register(new MonsterInvocationModule(container.get(MonsterService.class), container.get(FighterFactory.class), fight));
fight.register(new SpiritualLeashModule(fight));
fight.register(new AiModule(container.get(AiFactory.class)));
fight.register(new AiModule(container.get(AiFactory.class), fight, Mockito.mock(Logger.class)));
fight.register(new FighterInitializationModule(container.get(GameConfiguration.class).fight()));

fighter1 = player.fighter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import fr.quatrevieux.araknemu.network.game.out.fight.action.ActionEffect;
import fr.quatrevieux.araknemu.network.game.out.fight.turn.FighterTurnOrder;
import fr.quatrevieux.araknemu.network.game.out.info.Error;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand All @@ -66,7 +67,7 @@ public void setUp() throws Exception {
super.setUp();

fight = createFight();
fight.register(new AiModule(container.get(AiFactory.class)));
fight.register(new AiModule(container.get(AiFactory.class), fight, Mockito.mock(Logger.class)));
fight.nextState();

caster = player.fighter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import fr.quatrevieux.araknemu.game.spell.Spell;
import fr.quatrevieux.araknemu.game.spell.effect.SpellEffect;
import fr.quatrevieux.araknemu.network.game.out.info.Error;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand All @@ -57,7 +58,7 @@ public void setUp() throws Exception {
dataSet.pushMonsterTemplateInvocations().pushMonsterSpellsInvocations().pushRewardItems();

fight = createFight();
fight.register(new AiModule(container.get(AiFactory.class)));
fight.register(new AiModule(container.get(AiFactory.class), fight, Mockito.mock(Logger.class)));
fight.nextState();

caster = player.fighter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import fr.quatrevieux.araknemu.game.spell.effect.target.SpellEffectTarget;
import fr.quatrevieux.araknemu.network.game.out.fight.action.ActionEffect;
import fr.quatrevieux.araknemu.network.game.out.fight.turn.FighterTurnOrder;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand Down Expand Up @@ -77,7 +78,7 @@ public void setUp() throws Exception {
.build(true)
;

fight.register(new AiModule(container.get(AiFactory.class)));
fight.register(new AiModule(container.get(AiFactory.class), fight, Mockito.mock(Logger.class)));
fight.nextState();

caster = player.fighter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import fr.quatrevieux.araknemu.network.game.out.fight.action.ActionEffect;
import fr.quatrevieux.araknemu.network.game.out.fight.turn.FighterTurnOrder;
import fr.quatrevieux.araknemu.network.game.out.info.Error;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand Down Expand Up @@ -70,7 +71,7 @@ public void setUp() throws Exception {
dataSet.pushMonsterTemplateInvocations().pushMonsterSpellsInvocations().pushRewardItems();

fight = createFight();
fight.register(new AiModule(container.get(AiFactory.class)));
fight.register(new AiModule(container.get(AiFactory.class), fight, Mockito.mock(Logger.class)));
fight.nextState();

caster = player.fighter();
Expand Down
Loading

0 comments on commit 2dbb4c6

Please sign in to comment.