forked from smarthi/rl4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for Malmo MDP (pull #21)
- Loading branch information
Showing
16 changed files
with
682 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> | ||
<parent> | ||
<groupId>org.deeplearning4j</groupId> | ||
<artifactId>rl4j</artifactId> | ||
<version>0.9.2-SNAPSHOT</version> | ||
</parent> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<artifactId>rl4j-malmo</artifactId> | ||
<packaging>jar</packaging> | ||
|
||
<name>rl4j-malmo</name> | ||
|
||
<properties> | ||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||
</properties> | ||
|
||
<dependencies> | ||
<dependency> | ||
<groupId>org.deeplearning4j</groupId> | ||
<artifactId>rl4j-api</artifactId> | ||
<version>${project.version}</version> | ||
</dependency> | ||
<dependency> | ||
<groupId>com.microsoft.msr.malmo</groupId> | ||
<artifactId>MalmoJavaJar</artifactId> | ||
<version>0.30.0</version> | ||
</dependency> | ||
</dependencies> | ||
</project> |
40 changes: 40 additions & 0 deletions
40
rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package org.deeplearning4j.malmo; | ||
|
||
import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||
|
||
/** | ||
* Abstract base class for all Malmo-specific action spaces | ||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. | ||
*/ | ||
public abstract class MalmoActionSpace extends DiscreteSpace { | ||
/** | ||
* Array of action strings that will be sent to Malmo | ||
*/ | ||
protected String[] actions; | ||
|
||
/** | ||
* Protected constructor | ||
* @param size number of discrete actions in this space | ||
*/ | ||
protected MalmoActionSpace(int size) { | ||
super(size); | ||
} | ||
|
||
@Override | ||
public Object encode(Integer action) { | ||
return actions[action]; | ||
} | ||
|
||
@Override | ||
public Integer noOp() { | ||
return -1; | ||
} | ||
|
||
/** | ||
* Sets the seed used for random generation of actions | ||
* @param seed random number generator seed | ||
*/ | ||
public void setRandomSeed(long seed) { | ||
rd.setSeed(seed); | ||
} | ||
} |
16 changes: 16 additions & 0 deletions
16
rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package org.deeplearning4j.malmo; | ||
|
||
/** | ||
* Action space that allows for a fixed set of specific Malmo actions | ||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. | ||
*/ | ||
public class MalmoActionSpaceDiscrete extends MalmoActionSpace { | ||
/** | ||
* Construct an actions space from an array of action strings | ||
* @param actions Array of action strings | ||
*/ | ||
public MalmoActionSpaceDiscrete(String... actions) { | ||
super(actions.length); | ||
this.actions = actions; | ||
} | ||
} |
32 changes: 32 additions & 0 deletions
32
rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.deeplearning4j.malmo; | ||
|
||
import java.util.Arrays; | ||
|
||
import org.deeplearning4j.rl4j.space.Encodable; | ||
|
||
/** | ||
* Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor | ||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. | ||
*/ | ||
public class MalmoBox implements Encodable { | ||
double[] value; | ||
|
||
/** | ||
* Construct state from an array of doubles | ||
* @param value state values | ||
*/ | ||
//TODO: If this constructor was added to "Box", we wouldn't need this class at all. | ||
public MalmoBox(double... value) { | ||
this.value = value; | ||
} | ||
|
||
@Override | ||
public double[] toArray() { | ||
return value; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Arrays.toString(value); | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package org.deeplearning4j.malmo; | ||
|
||
/** | ||
* Exception thrown when Malmo cannot connect to a client after multiple retries | ||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. | ||
*/ | ||
public class MalmoConnectionError extends RuntimeException { | ||
private static final long serialVersionUID = -9034754802977073358L; | ||
|
||
public MalmoConnectionError(String string) { | ||
super(string); | ||
} | ||
} |
27 changes: 27 additions & 0 deletions
27
rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package org.deeplearning4j.malmo; | ||
|
||
import java.util.Arrays; | ||
|
||
import com.microsoft.msr.malmo.WorldState; | ||
|
||
/** | ||
* A Malmo consistency policy that ensures the both there is a reward and next observation has a different position that the previous one. | ||
* This will only work for your mission if you require that every action moves to a new location. | ||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. | ||
*/ | ||
public class MalmoDescretePositionPolicy implements MalmoObservationPolicy { | ||
MalmoObservationSpacePosition observationSpace = new MalmoObservationSpacePosition(); | ||
|
||
@Override | ||
public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) { | ||
MalmoBox last_observation = observationSpace.getObservation(world_state); | ||
MalmoBox old_observation = observationSpace.getObservation(original_world_state); | ||
|
||
double[] newvalues = old_observation == null ? null : old_observation.toArray(); | ||
double[] oldvalues = last_observation == null ? null : last_observation.toArray(); | ||
|
||
return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty() | ||
|| Arrays.equals(oldvalues, newvalues)); | ||
} | ||
|
||
} |
Oops, something went wrong.