From a1308b9713a9baaddbbc18c450b61406b61587d0 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sat, 18 May 2019 15:48:24 +0200 Subject: [PATCH 1/2] Gradle version --- .gitignore | 30 +- build.gradle | 76 ++ proguard-rules.pro | 18 + settings.gradle | 1 + src/main/AndroidManifest.xml | 2 + .../AbstractActionSelectionStrategy.java | 96 ++- .../ActionSelectionStrategy.java | 12 +- .../ActionSelectionStrategyFactory.java | 96 ++- .../EpsilonGreedyActionSelectionStrategy.java | 113 ++- .../GibbsSoftMaxActionSelectionStrategy.java | 92 ++- .../GreedyActionSelectionStrategy.java | 27 +- .../SoftMaxActionSelectionStrategy.java | 39 +- .../actorcritic/ActorCriticAgent.java | 168 +++-- .../actorcritic/ActorCriticLambdaLearner.java | 224 +++--- .../actorcritic/ActorCriticLearner.java | 181 ++--- .../chen0040/rl/learning/qlearn/QAgent.java | 202 +++--- .../rl/learning/qlearn/QLambdaLearner.java | 238 ++++--- .../chen0040/rl/learning/qlearn/QLearner.java | 245 ++++--- .../chen0040/rl/learning/rlearn/RAgent.java | 175 +++-- .../chen0040/rl/learning/rlearn/RLearner.java | 238 +++---- .../rl/learning/sarsa/SarsaAgent.java | 235 ++++--- .../rl/learning/sarsa/SarsaLambdaLearner.java | 235 ++++--- .../rl/learning/sarsa/SarsaLearner.java | 217 +++--- .../rl/models/EligibilityTraceUpdateMode.java | 4 +- .../com/github/chen0040/rl/models/QModel.java | 292 ++++---- .../chen0040/rl/models/UtilityModel.java | 154 ++--- .../github/chen0040/rl/utils/DoubleUtils.java | 12 +- .../github/chen0040/rl/utils/IndexValue.java | 81 ++- .../com/github/chen0040/rl/utils/Matrix.java | 443 ++++++------ .../github/chen0040/rl/utils/MatrixUtils.java | 34 +- .../github/chen0040/rl/utils/TupleTwo.java | 100 +-- .../com/github/chen0040/rl/utils/Vec.java | 650 +++++++++--------- .../github/chen0040/rl/utils/VectorUtils.java | 49 +- .../actorcritic/ActorCriticAgentUnitTest.java | 47 +- .../ActorCriticLearnerUnitTest.java | 53 +- .../rl/learning/models/QModelUnitTest.java | 39 +- .../rl/learning/qlearn/QAgentUnitTest.java | 40 +- .../rl/learning/qlearn/QLearnerUnitTest.java | 67 +- .../rl/learning/rlearn/RAgentUnitTest.java | 46 +- .../rl/learning/sarsa/SarsaAgentUnitTest.java | 44 +- .../rl/learning/utils/MatrixUnitTest.java | 90 +-- .../rl/learning/utils/VecUnitTest.java | 24 +- src/test/resources/log4j.properties | 2 - 43 files changed, 2640 insertions(+), 2591 deletions(-) create mode 100644 build.gradle create mode 100644 proguard-rules.pro create mode 100644 settings.gradle create mode 100644 src/main/AndroidManifest.xml diff --git a/.gitignore b/.gitignore index ea24784..e33f790 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,18 @@ -.idea/ -*.iml - -*.class - -# Mobile Tools for Java (J2ME) -.mtj.tmp/ +# generated and compiled +bin/ +gen/ -# Package Files # -*.war -*.ear +# Gradle Build +.gradle/ +*/build/ +build/ -Thumbs.db - -target/ +# setting +.idea/ +local.properties +.classpath +.project -# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml -hs_err_pid* +#Android Studio +*.iml +*.iws diff --git a/build.gradle b/build.gradle new file mode 100644 index 0000000..0051418 --- /dev/null +++ b/build.gradle @@ -0,0 +1,76 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. 
+ +buildscript { + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:3.5.0-beta01' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + jcenter() + } +} + +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + + defaultConfig { + minSdkVersion 24 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + + lintOptions { + abortOnError false + } + + buildTypes { + debug { + testCoverageEnabled false + } + } + + sourceSets { + main { + java { + // Merge source sets instead of adding rushcore as submodule so that the test coverage report works + srcDirs = ['src/main/java'] + } + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + testImplementation 'junit:junit:4.12' + implementation 'androidx.appcompat:appcompat:1.0.2' + + compileOnly 'org.projectlombok:lombok:1.18.8' + annotationProcessor 'org.projectlombok:lombok:1.18.8' + + implementation 'com.google.code.gson:gson:2.8.5' + implementation 'org.testng:testng:6.9.6' + implementation 'org.assertj:assertj-core:3.12.2' +} \ No newline at end of file diff --git a/proguard-rules.pro b/proguard-rules.pro new file mode 100644 index 0000000..92de838 --- /dev/null +++ b/proguard-rules.pro @@ -0,0 +1,18 @@ +# Add project specific ProGuard rules here. +# By default, the flags in this file are appended to flags specified +# in /Users/Stuart/Development/sdk/tools/proguard/proguard-android.txt +# You can edit the include path and order by changing the proguardFiles +# directive in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# Add any project specific keep options here: + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 0000000..0720936 --- /dev/null +++ b/settings.gradle @@ -0,0 +1 @@ +//include ':java-reinforcement-learning' diff --git a/src/main/AndroidManifest.xml b/src/main/AndroidManifest.xml new file mode 100644 index 0000000..f71a439 --- /dev/null +++ b/src/main/AndroidManifest.xml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java index 7de7f9b..92a5a55 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java @@ -8,66 +8,64 @@ import java.util.Map; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. 
*/ public abstract class AbstractActionSelectionStrategy implements ActionSelectionStrategy { - private String prototype; - protected Map attributes = new HashMap(); - - public String getPrototype(){ - return prototype; - } + private String prototype; + protected Map attributes = new HashMap<>(); - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return new IndexValue(); - } + public String getPrototype() { + return prototype; + } - public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) { - return new IndexValue(); - } + public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { + return new IndexValue(); + } - public AbstractActionSelectionStrategy(){ - prototype = this.getClass().getCanonicalName(); - } + public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) { + return new IndexValue(); + } + public AbstractActionSelectionStrategy() { + prototype = this.getClass().getCanonicalName(); + } - public AbstractActionSelectionStrategy(HashMap attributes){ - this.attributes = attributes; - if(attributes.containsKey("prototype")){ - this.prototype = attributes.get("prototype"); - } - } + public AbstractActionSelectionStrategy(HashMap attributes) { + this.attributes = attributes; + if (attributes.containsKey("prototype")) { + this.prototype = attributes.get("prototype"); + } + } - public Map getAttributes(){ - return attributes; - } + public Map getAttributes() { + return attributes; + } - @Override - public boolean equals(Object obj) { - ActionSelectionStrategy rhs = (ActionSelectionStrategy)obj; - if(!prototype.equalsIgnoreCase(rhs.getPrototype())) return false; - for(Map.Entry entry : rhs.getAttributes().entrySet()) { - if(!attributes.containsKey(entry.getKey())) { - return false; - } - if(!attributes.get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - for(Map.Entry entry : attributes.entrySet()) { - if(!rhs.getAttributes().containsKey(entry.getKey())) { - return false; - } - if(!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - return true; - } + @Override + public boolean equals(Object obj) { + ActionSelectionStrategy rhs = (ActionSelectionStrategy) obj; + if (!prototype.equalsIgnoreCase(rhs.getPrototype())) return false; + for (Map.Entry entry : rhs.getAttributes().entrySet()) { + if (!attributes.containsKey(entry.getKey())) { + return false; + } + if (!attributes.get(entry.getKey()).equals(entry.getValue())) { + return false; + } + } + for (Map.Entry entry : attributes.entrySet()) { + if (!rhs.getAttributes().containsKey(entry.getKey())) { + return false; + } + if (!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())) { + return false; + } + } + return true; + } - @Override - public abstract Object clone(); + @Override + public abstract Object clone(); } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java index 51b6824..ff92269 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java @@ -9,13 +9,15 @@ import java.util.Map; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. 
*/ public interface ActionSelectionStrategy extends Serializable, Cloneable { - IndexValue selectAction(int stateId, QModel model, Set actionsAtState); - IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState); - String getPrototype(); - Map getAttributes(); + IndexValue selectAction(int stateId, QModel model, Set actionsAtState); + + IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState); + + String getPrototype(); + + Map getAttributes(); } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java index ce92be0..6159678 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java @@ -3,57 +3,55 @@ import java.util.HashMap; import java.util.Map; - /** * Created by xschen on 9/27/2015 0027. */ public class ActionSelectionStrategyFactory { - public static ActionSelectionStrategy deserialize(String conf){ - String[] comps = conf.split(";"); - - HashMap attributes = new HashMap(); - for(int i=0; i < comps.length; ++i){ - String comp = comps[i]; - String[] field = comp.split("="); - if(field.length < 2) continue; - String fieldname = field[0].trim(); - String fieldvalue = field[1].trim(); - - attributes.put(fieldname, fieldvalue); - } - if(attributes.isEmpty()){ - attributes.put("prototype", conf); - } - - String prototype = attributes.get("prototype"); - if(prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())){ - return new GreedyActionSelectionStrategy(); - } else if(prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())){ - return new SoftMaxActionSelectionStrategy(); - } else if(prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())){ - return new EpsilonGreedyActionSelectionStrategy(attributes); - } else if(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())){ - return new GibbsSoftMaxActionSelectionStrategy(); - } - - return null; - } - - public static String serialize(ActionSelectionStrategy strategy){ - Map attributes = strategy.getAttributes(); - attributes.put("prototype", strategy.getPrototype()); - - StringBuilder sb = new StringBuilder(); - boolean first = true; - for(Map.Entry entry : attributes.entrySet()){ - if(first){ - first = false; - } - else{ - sb.append(";"); - } - sb.append(entry.getKey()+"="+entry.getValue()); - } - return sb.toString(); - } + public static ActionSelectionStrategy deserialize(String conf) { + String[] comps = conf.split(";"); + + HashMap attributes = new HashMap<>(); + for (String comp : comps) { + String[] field = comp.split("="); + if (field.length < 2) continue; + String fieldname = field[0].trim(); + String fieldvalue = field[1].trim(); + + attributes.put(fieldname, fieldvalue); + } + if (attributes.isEmpty()) { + attributes.put("prototype", conf); + } + + String prototype = attributes.get("prototype"); + if (prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())) { + return new GreedyActionSelectionStrategy(); + } else if (prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())) { + return new SoftMaxActionSelectionStrategy(); + } else if (prototype + .equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())) { + return new EpsilonGreedyActionSelectionStrategy(attributes); + } else if 
(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())) { + return new GibbsSoftMaxActionSelectionStrategy(); + } + + return null; + } + + public static String serialize(ActionSelectionStrategy strategy) { + Map attributes = strategy.getAttributes(); + attributes.put("prototype", strategy.getPrototype()); + + StringBuilder sb = new StringBuilder(); + boolean first = true; + for (Map.Entry entry : attributes.entrySet()) { + if (first) { + first = false; + } else { + sb.append(";"); + } + sb.append(entry.getKey() + "=" + entry.getValue()); + } + return sb.toString(); + } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java index 5f7db9a..3d2e4e7 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java @@ -5,75 +5,74 @@ import java.util.*; - /** * Created by xschen on 9/27/2015 0027. */ public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { - public static final String EPSILON = "epsilon"; - private Random random = new Random(); + public static final String EPSILON = "epsilon"; + private Random random = new Random(); - @Override - public Object clone(){ - EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); - clone.copy(this); - return clone; - } + @Override + public Object clone() { + EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); + clone.copy(this); + return clone; + } - public void copy(EpsilonGreedyActionSelectionStrategy rhs){ - random = rhs.random; - for(Map.Entry entry : rhs.attributes.entrySet()){ - attributes.put(entry.getKey(), entry.getValue()); - } - } + public void copy(EpsilonGreedyActionSelectionStrategy rhs) { + random = rhs.random; + for (Map.Entry entry : rhs.attributes.entrySet()) { + attributes.put(entry.getKey(), entry.getValue()); + } + } - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){ - EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj; - if(epsilon() != rhs.epsilon()) return false; - // if(!random.equals(rhs.random)) return false; - return true; - } - return false; - } + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy) { + EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy) obj; + if (epsilon() != rhs.epsilon()) return false; + // if(!random.equals(rhs.random)) return false; + return true; + } + return false; + } - private double epsilon(){ - return Double.parseDouble(attributes.get(EPSILON)); - } + private double epsilon() { + return Double.parseDouble(attributes.get(EPSILON)); + } - public EpsilonGreedyActionSelectionStrategy(){ - epsilon(0.1); - } + public EpsilonGreedyActionSelectionStrategy() { + epsilon(0.1); + } - public EpsilonGreedyActionSelectionStrategy(HashMap attributes){ - super(attributes); - } + public EpsilonGreedyActionSelectionStrategy(HashMap attributes) { + super(attributes); + } - private void epsilon(double value){ - attributes.put(EPSILON, "" + value); - } + private void epsilon(double value) { + attributes.put(EPSILON, "" + value); + } - public EpsilonGreedyActionSelectionStrategy(Random random){ - 
this.random = random;
-        epsilon(0.1);
-    }
+    public EpsilonGreedyActionSelectionStrategy(Random random) {
+        this.random = random;
+        epsilon(0.1);
+    }
 
-    @Override
-    public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
-        if(random.nextDouble() < 1- epsilon()){
-            return model.actionWithMaxQAtState(stateId, actionsAtState);
-        }else{
-            int actionId;
-            if(actionsAtState != null && !actionsAtState.isEmpty()) {
-                List<Integer> actions = new ArrayList<>(actionsAtState);
-                actionId = actions.get(random.nextInt(actions.size()));
-            } else {
-                actionId = random.nextInt(model.getActionCount());
-            }
+    @Override
+    public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+        if (random.nextDouble() < 1 - epsilon()) {
+            return model.actionWithMaxQAtState(stateId, actionsAtState);
+        } else {
+            int actionId;
+            if (actionsAtState != null && !actionsAtState.isEmpty()) {
+                List<Integer> actions = new ArrayList<>(actionsAtState);
+                actionId = actions.get(random.nextInt(actions.size()));
+            } else {
+                actionId = random.nextInt(model.getActionCount());
+            }
 
-            double Q = model.getQ(stateId, actionId);
-            return new IndexValue(actionId, Q);
-        }
-    }
+            double Q = model.getQ(stateId, actionId);
+            return new IndexValue(actionId, Q);
+        }
+    }
 }
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
index 8b2d8d2..12f1573 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
@@ -8,64 +8,62 @@
 import java.util.Random;
 import java.util.Set;
 
-
 /**
  * Created by xschen on 9/28/2015 0028.
 */
 public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
-    private Random random = null;
-    public GibbsSoftMaxActionSelectionStrategy(){
-        random = new Random();
-    }
+    private Random random = null;
+
+    public GibbsSoftMaxActionSelectionStrategy() {
+        random = new Random();
+    }
 
-    public GibbsSoftMaxActionSelectionStrategy(Random random){
-        this.random = random;
-    }
+    public GibbsSoftMaxActionSelectionStrategy(Random random) {
+        this.random = random;
+    }
 
-    @Override
-    public Object clone() {
-        GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy();
-        return clone;
-    }
+    @Override
+    public Object clone() {
+        GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy();
+        return clone;
+    }
 
-    @Override
-    public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
-        List<Integer> actions = new ArrayList<Integer>();
-        if(actionsAtState == null){
-            for(int i=0; i < model.getActionCount(); ++i){
-                actions.add(i);
-            }
-        }else{
-            for(Integer actionId : actionsAtState){
-                actions.add(actionId);
-            }
-        }
+    @Override
+    public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+        List<Integer> actions = new ArrayList<>();
+        if (actionsAtState == null) {
+            for (int i = 0; i < model.getActionCount(); ++i) {
+                actions.add(i);
+            }
+        } else {
+            actions.addAll(actionsAtState);
+        }
 
-        double sum = 0;
-        List<Double> plist = new ArrayList<Double>();
-        for(int i=0; i < actions.size(); ++i){
-            int actionId = actions.get(i);
-            double p = Math.exp(model.getQ(stateId, actionId));
-            sum += p;
-            plist.add(sum);
-        }
+        double sum = 0;
+        List<Double> plist = new ArrayList<>();
+        for (int i = 0; i < actions.size(); ++i) {
+            int actionId = actions.get(i);
+            double p = Math.exp(model.getQ(stateId, actionId));
+            sum += p;
+            plist.add(sum);
+        }
 
-        IndexValue iv = new IndexValue();
-        iv.setIndex(-1);
-        iv.setValue(Double.NEGATIVE_INFINITY);
+        IndexValue iv = new IndexValue();
+        iv.setIndex(-1);
+        iv.setValue(Double.NEGATIVE_INFINITY);
 
-        double r = sum * random.nextDouble();
-        for(int i=0; i < actions.size(); ++i){
+        double r = sum * random.nextDouble();
+        for (int i = 0; i < actions.size(); ++i) {
 
-            if(plist.get(i) >= r){
-                int actionId = actions.get(i);
-                iv.setValue(model.getQ(stateId, actionId));
-                iv.setIndex(actionId);
-                break;
-            }
-        }
+            if (plist.get(i) >= r) {
+                int actionId = actions.get(i);
+                iv.setValue(model.getQ(stateId, actionId));
+                iv.setIndex(actionId);
+                break;
+            }
+        }
 
-        return iv;
-    }
+        return iv;
+    }
 }
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
index 6b0f350..8d6d7f3 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
@@ -5,24 +5,23 @@
 import java.util.Set;
 
-
 /**
  * Created by xschen on 9/27/2015 0027.
*/ public class GreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { - @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return model.actionWithMaxQAtState(stateId, actionsAtState); - } + @Override + public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { + return model.actionWithMaxQAtState(stateId, actionsAtState); + } - @Override - public Object clone(){ - GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy(); - return clone; - } + @Override + public Object clone() { + GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy(); + return clone; + } - @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof GreedyActionSelectionStrategy; - } + @Override + public boolean equals(Object obj) { + return obj != null && obj instanceof GreedyActionSelectionStrategy; + } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java index f9735b9..51ef128 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java @@ -6,34 +6,33 @@ import java.util.Random; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. */ public class SoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy { - private Random random = new Random(); + private Random random = new Random(); - @Override - public Object clone(){ - SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random); - return clone; - } + @Override + public Object clone() { + SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random); + return clone; + } - @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof SoftMaxActionSelectionStrategy; - } + @Override + public boolean equals(Object obj) { + return obj != null && obj instanceof SoftMaxActionSelectionStrategy; + } - public SoftMaxActionSelectionStrategy(){ + public SoftMaxActionSelectionStrategy() { - } + } - public SoftMaxActionSelectionStrategy(Random random){ - this.random = random; - } + public SoftMaxActionSelectionStrategy(Random random) { + this.random = random; + } - @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random); - } + @Override + public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { + return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random); + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java index 6e34874..f262afe 100644 --- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java @@ -7,95 +7,91 @@ import java.util.Set; import java.util.function.Function; - /** * Created by chen0469 on 9/28/2015 0028. 
*/ public class ActorCriticAgent implements Serializable { - private ActorCriticLearner learner; - private int currentState; - private int prevState; - private int prevAction; - - public void enableEligibilityTrace(double lambda){ - ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; - } - - public void start(int stateId){ - currentState = stateId; - prevAction = -1; - prevState = -1; - } - - public ActorCriticLearner getLearner(){ - return learner; - } - - public void setLearner(ActorCriticLearner learner){ - this.learner = learner; - } - - public ActorCriticAgent(int stateCount, int actionCount){ - learner = new ActorCriticLearner(stateCount, actionCount); - } - - public ActorCriticAgent(){ - - } - - public ActorCriticAgent(ActorCriticLearner learner){ - this.learner = learner; - } - - public ActorCriticAgent makeCopy(){ - ActorCriticAgent clone = new ActorCriticAgent(); - clone.copy(this); - return clone; - } - - public void copy(ActorCriticAgent rhs){ - learner = (ActorCriticLearner)rhs.learner.makeCopy(); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof ActorCriticAgent){ - ActorCriticAgent rhs = (ActorCriticAgent)obj; - return learner.equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState; - - } - return false; - } - - public int selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); - } - - public int selectAction(){ - return learner.selectAction(currentState); - } - - public void update(int actionTaken, int newState, double immediateReward, final Vec V){ - update(actionTaken, newState, null, immediateReward, V); - } - - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V){ - - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, new Function() { - public Double apply(Integer stateId) { - return V.get(stateId); - } - }); - - prevAction = actionTaken; - prevState = currentState; - - currentState = newState; - } + private ActorCriticLearner learner; + private int currentState; + private int prevState; + private int prevAction; + + public void enableEligibilityTrace(double lambda) { + ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner); + acll.setLambda(lambda); + learner = acll; + } + + public void start(int stateId) { + currentState = stateId; + prevAction = -1; + prevState = -1; + } + + public ActorCriticLearner getLearner() { + return learner; + } + + public void setLearner(ActorCriticLearner learner) { + this.learner = learner; + } + + public ActorCriticAgent(int stateCount, int actionCount) { + learner = new ActorCriticLearner(stateCount, actionCount); + } + + public ActorCriticAgent() { + + } + + public ActorCriticAgent(ActorCriticLearner learner) { + this.learner = learner; + } + + public ActorCriticAgent makeCopy() { + ActorCriticAgent clone = new ActorCriticAgent(); + clone.copy(this); + return clone; + } + + public void copy(ActorCriticAgent rhs) { + learner = (ActorCriticLearner) rhs.learner.makeCopy(); + prevAction = rhs.prevAction; + prevState = rhs.prevState; + currentState = rhs.currentState; + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof ActorCriticAgent) { + ActorCriticAgent rhs = (ActorCriticAgent) obj; + return 
learner + .equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState; + + } + return false; + } + + public int selectAction(Set actionsAtState) { + return learner.selectAction(currentState, actionsAtState); + } + + public int selectAction() { + return learner.selectAction(currentState); + } + + public void update(int actionTaken, int newState, double immediateReward, final Vec V) { + update(actionTaken, newState, null, immediateReward, V); + } + + public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V) { + + learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, V::get); + + prevAction = actionTaken; + prevState = currentState; + + currentState = newState; + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java index d68f978..1ac041d 100644 --- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java @@ -1,129 +1,123 @@ package com.github.chen0040.rl.learning.actorcritic; - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; import java.util.Set; import java.util.function.Function; - /** * Created by chen0469 on 9/28/2015 0028. */ public class ActorCriticLambdaLearner extends ActorCriticLearner { - private Matrix e; - private double lambda = 0.9; - private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; - - public ActorCriticLambdaLearner(){ - super(); - } - - public ActorCriticLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - - - public ActorCriticLambdaLearner(ActorCriticLearner learner){ - copy(learner); - e = new Matrix(P.getStateCount(), P.getActionCount()); - } - - public ActorCriticLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double lambda, double initialP){ - super(stateCount, actionCount, alpha, gamma, initialP); - this.lambda = lambda; - e = new Matrix(stateCount, actionCount); - } - - public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; - } - - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { - this.traceUpdateMode = traceUpdateMode; - } - - public double getLambda(){ - return lambda; - } - - public void setLambda(double lambda){ - this.lambda = lambda; - } - - public ActorCriticLambdaLearner makeCopy(){ - ActorCriticLambdaLearner clone = new ActorCriticLambdaLearner(); - clone.copy(this); - return clone; - } - - @Override - public void copy(ActorCriticLearner rhs){ - super.copy(rhs); - - ActorCriticLambdaLearner rhs2 = (ActorCriticLambdaLearner)rhs; - e = rhs2.e.makeCopy(); - lambda = rhs2.lambda; - traceUpdateMode = rhs2.traceUpdateMode; - } - - @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ - return false; - } - - if(obj instanceof ActorCriticLambdaLearner){ - ActorCriticLambdaLearner rhs = (ActorCriticLambdaLearner)obj; - return e.equals(rhs.e) && lambda == rhs.lambda && traceUpdateMode == 
rhs.traceUpdateMode; - } - - return false; - } - - public Matrix getEligibility(){ - return e; - } - - public void setEligibility(Matrix e){ - this.e = e; - } - - @Override - public void update(int currentStateId, int currentActionId, int newStateId, Set actionsAtNewState, double immediateReward, Function V){ - - double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId); - - int stateCount = P.getStateCount(); - int actionCount = P.getActionCount(); - - double gamma = P.getGamma(); - - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); - - - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - - double oldP = P.getQ(stateId, actionId); - double alpha = P.getAlpha(currentStateId, currentActionId); - double newP = oldP + alpha * td_error * e.get(stateId, actionId); - - P.setQ(stateId, actionId, newP); - - if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); - } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); - } - } - } - } - + private Matrix e; + private double lambda = 0.9; + private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + + public ActorCriticLambdaLearner() { + super(); + } + + public ActorCriticLambdaLearner(int stateCount, int actionCount) { + super(stateCount, actionCount); + e = new Matrix(stateCount, actionCount); + } + + public ActorCriticLambdaLearner(ActorCriticLearner learner) { + copy(learner); + e = new Matrix(P.getStateCount(), P.getActionCount()); + } + + public ActorCriticLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double lambda, double initialP) { + super(stateCount, actionCount, alpha, gamma, initialP); + this.lambda = lambda; + e = new Matrix(stateCount, actionCount); + } + + public EligibilityTraceUpdateMode getTraceUpdateMode() { + return traceUpdateMode; + } + + public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + this.traceUpdateMode = traceUpdateMode; + } + + public double getLambda() { + return lambda; + } + + public void setLambda(double lambda) { + this.lambda = lambda; + } + + public ActorCriticLambdaLearner makeCopy() { + ActorCriticLambdaLearner clone = new ActorCriticLambdaLearner(); + clone.copy(this); + return clone; + } + + @Override + public void copy(ActorCriticLearner rhs) { + super.copy(rhs); + + ActorCriticLambdaLearner rhs2 = (ActorCriticLambdaLearner) rhs; + e = rhs2.e.makeCopy(); + lambda = rhs2.lambda; + traceUpdateMode = rhs2.traceUpdateMode; + } + + @Override + public boolean equals(Object obj) { + if (!super.equals(obj)) { + return false; + } + + if (obj instanceof ActorCriticLambdaLearner) { + ActorCriticLambdaLearner rhs = (ActorCriticLambdaLearner) obj; + return e.equals(rhs.e) && lambda == rhs.lambda && traceUpdateMode == rhs.traceUpdateMode; + } + + return false; + } + + public Matrix getEligibility() { + return e; + } + + public void setEligibility(Matrix e) { + this.e = e; + } + + @Override + public void update(int currentStateId, int currentActionId, int newStateId, Set actionsAtNewState, double immediateReward, Function V) { + + double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId); + + int stateCount = P.getStateCount(); + int actionCount = P.getActionCount(); + + double gamma = P.getGamma(); + + e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + + for (int stateId = 0; stateId < stateCount; ++stateId) 
{ + for (int actionId = 0; actionId < actionCount; ++actionId) { + + double oldP = P.getQ(stateId, actionId); + double alpha = P.getAlpha(currentStateId, currentActionId); + double newP = oldP + alpha * td_error * e.get(stateId, actionId); + + P.setQ(stateId, actionId, newP); + + if (actionId != currentActionId) { + e.set(currentStateId, actionId, 0); + } else { + e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + } + } + } + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java index d106c94..2a25094 100644 --- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java @@ -1,8 +1,7 @@ package com.github.chen0040.rl.learning.actorcritic; - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; @@ -10,100 +9,102 @@ import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Vec; +import com.google.gson.Gson; import java.io.Serializable; import java.util.Random; import java.util.Set; import java.util.function.Function; - /** * Created by chen0469 on 9/28/2015 0028. */ -public class ActorCriticLearner implements Serializable{ - protected QModel P; - protected ActionSelectionStrategy actionSelectionStrategy; - - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } - - public static ActorCriticLearner fromJson(String json){ - return JSON.parseObject(json, ActorCriticLearner.class); - } - - public Object makeCopy(){ - ActorCriticLearner clone = new ActorCriticLearner(); - clone.copy(this); - return clone; - } - - public void copy(ActorCriticLearner rhs){ - P = rhs.P.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone(); - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof ActorCriticLearner){ - ActorCriticLearner rhs = (ActorCriticLearner)obj; - return P.equals(rhs.P) && getActionSelection().equals(rhs.getActionSelection()); - } - return false; - } - - public ActorCriticLearner(){ - - } - - public ActorCriticLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 1, 0.7, 0.01); - } - - public int selectAction(int stateId, Set actionsAtState){ - IndexValue iv = actionSelectionStrategy.selectAction(stateId, P, actionsAtState); - return iv.getIndex(); - } - - public int selectAction(int stateId){ - return selectAction(stateId, null); - } - - public ActorCriticLearner(int stateCount, int actionCount, double beta, double gamma, double initialP){ - P = new QModel(stateCount, actionCount, initialP); - P.setAlpha(beta); - P.setGamma(gamma); - - actionSelectionStrategy = new GibbsSoftMaxActionSelectionStrategy(); - } - - public void update(int currentStateId, int currentActionId, int newStateId, double immediateReward, Function V){ - update(currentStateId, currentActionId, newStateId, null, immediateReward, V); - } - - public void update(int 
currentStateId, int currentActionId, int newStateId,Set actionsAtNewState, double immediateReward, Function V){ - double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId); - - double oldP = P.getQ(currentStateId, currentActionId); - double beta = P.getAlpha(currentStateId, currentActionId); - double newP = oldP + beta * td_error; - P.setQ(currentStateId, currentActionId, newP); - } - - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); - } - - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); - } - - - public QModel getP() { - return P; - } - - public void setP(QModel p) { - P = p; - } +public class ActorCriticLearner implements Serializable { + protected QModel P; + protected ActionSelectionStrategy actionSelectionStrategy; + + public String toJson() { + return new Gson().toJson(this); +// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public static ActorCriticLearner fromJson(String json) { + return new Gson().fromJson(json, ActorCriticLearner.class); +// return JSON.parseObject(json, ActorCriticLearner.class); + } + + public Object makeCopy() { + ActorCriticLearner clone = new ActorCriticLearner(); + clone.copy(this); + return clone; + } + + public void copy(ActorCriticLearner rhs) { + P = rhs.P.makeCopy(); + actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy) + .clone(); + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof ActorCriticLearner) { + ActorCriticLearner rhs = (ActorCriticLearner) obj; + return P.equals(rhs.P) && getActionSelection().equals(rhs.getActionSelection()); + } + return false; + } + + public ActorCriticLearner() { + + } + + public ActorCriticLearner(int stateCount, int actionCount) { + this(stateCount, actionCount, 1, 0.7, 0.01); + } + + public int selectAction(int stateId, Set actionsAtState) { + IndexValue iv = actionSelectionStrategy.selectAction(stateId, P, actionsAtState); + return iv.getIndex(); + } + + public int selectAction(int stateId) { + return selectAction(stateId, null); + } + + public ActorCriticLearner(int stateCount, int actionCount, double beta, double gamma, double initialP) { + P = new QModel(stateCount, actionCount, initialP); + P.setAlpha(beta); + P.setGamma(gamma); + + actionSelectionStrategy = new GibbsSoftMaxActionSelectionStrategy(); + } + + public void update(int currentStateId, int currentActionId, int newStateId, double immediateReward, Function V) { + update(currentStateId, currentActionId, newStateId, null, immediateReward, V); + } + + public void update(int currentStateId, int currentActionId, int newStateId, Set actionsAtNewState, double immediateReward, Function V) { + double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId); + + double oldP = P.getQ(currentStateId, currentActionId); + double beta = P.getAlpha(currentStateId, currentActionId); + double newP = oldP + beta * td_error; + P.setQ(currentStateId, currentActionId, newP); + } + + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + } + + public void setActionSelection(String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } + + public QModel getP() { + return P; + } + + public void setP(QModel p) { + P = p; + } } diff --git 
a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java index afdb314..8b1a84e 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java @@ -6,107 +6,109 @@ import java.util.Random; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. */ -public class QAgent implements Serializable{ - private QLearner learner; - private int currentState; - private int prevState; - - /** action taken at prevState */ - private int prevAction; - - public int getCurrentState(){ - return currentState; - } - - public int getPrevState(){ - return prevState; - } - - public int getPrevAction(){ - return prevAction; - } - - public void start(int currentState){ - this.currentState = currentState; - this.prevAction = -1; - this.prevState = -1; - } - - public IndexValue selectAction(){ - return learner.selectAction(currentState); - } - - public IndexValue selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); - } - - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); - } - - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ - - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward); - - prevState = currentState; - prevAction = actionTaken; - - currentState = newState; - } - - public void enableEligibilityTrace(double lambda){ - QLambdaLearner acll = new QLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; - } - - public QLearner getLearner(){ - return learner; - } - - public void setLearner(QLearner learner){ - this.learner = learner; - } - - public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); - } - - public QAgent(QLearner learner){ - this.learner = learner; - } - - public QAgent(int stateCount, int actionCount){ - learner = new QLearner(stateCount, actionCount); - } - - public QAgent(){ - - } - - public QAgent makeCopy(){ - QAgent clone = new QAgent(); - clone.copy(this); - return clone; - } - - public void copy(QAgent rhs){ - learner.copy(rhs.learner); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof QAgent){ - QAgent rhs = (QAgent)obj; - return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner.equals(rhs.learner); - } - return false; - } +public class QAgent implements Serializable { + private QLearner learner; + private int currentState; + private int prevState; + + /** + * action taken at prevState + */ + private int prevAction; + + public int getCurrentState() { + return currentState; + } + + public int getPrevState() { + return prevState; + } + + public int getPrevAction() { + return prevAction; + } + + public void start(int currentState) { + this.currentState = currentState; + this.prevAction = -1; + this.prevState = -1; + } + + public IndexValue selectAction() { + return learner.selectAction(currentState); + } + + public IndexValue selectAction(Set actionsAtState) { + return learner.selectAction(currentState, actionsAtState); + } + + public void update(int actionTaken, int newState, double immediateReward) { + 
update(actionTaken, newState, null, immediateReward); + } + + public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward) { + + learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward); + + prevState = currentState; + prevAction = actionTaken; + + currentState = newState; + } + + public void enableEligibilityTrace(double lambda) { + QLambdaLearner acll = new QLambdaLearner(learner); + acll.setLambda(lambda); + learner = acll; + } + + public QLearner getLearner() { + return learner; + } + + public void setLearner(QLearner learner) { + this.learner = learner; + } + + public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); + } + + public QAgent(QLearner learner) { + this.learner = learner; + } + + public QAgent(int stateCount, int actionCount) { + learner = new QLearner(stateCount, actionCount); + } + + public QAgent() { + + } + + public QAgent makeCopy() { + QAgent clone = new QAgent(); + clone.copy(this); + return clone; + } + + public void copy(QAgent rhs) { + learner.copy(rhs.learner); + prevAction = rhs.prevAction; + prevState = rhs.prevState; + currentState = rhs.currentState; + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof QAgent) { + QAgent rhs = (QAgent) obj; + return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner + .equals(rhs.learner); + } + return false; + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java index 875ef3a..df75cd2 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java @@ -1,135 +1,129 @@ package com.github.chen0040.rl.learning.qlearn; - import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; import java.util.Set; - /** * Created by xschen on 9/28/2015 0028. 
*/ public class QLambdaLearner extends QLearner { - private double lambda = 0.9; - private Matrix e; - private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; - - public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; - } - - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { - this.traceUpdateMode = traceUpdateMode; - } - - public double getLambda(){ - return lambda; - } - - public void setLambda(double lambda){ - this.lambda = lambda; - } - - public QLambdaLearner makeCopy(){ - QLambdaLearner clone = new QLambdaLearner(); - clone.copy(this); - return clone; - } - - @Override - public void copy(QLearner rhs){ - super.copy(rhs); - - QLambdaLearner rhs2 = (QLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; - } - - public QLambdaLearner(QLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); - } - - @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ - return false; - } - - if(obj instanceof QLambdaLearner){ - QLambdaLearner rhs = (QLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; - } - - return false; - } - - public QLambdaLearner(){ - super(); - } - - public QLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); - } - - public Matrix getEligibility() - { - return e; - } - - public void setEligibility(Matrix e){ - this.e = e; - } - - @Override - public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { - // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); - - // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); - - // discount_rate; - double gamma = model.getGamma(); - - // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); - - double td_error = immediateReward + gamma * maxQ - oldQ; - - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); - - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); - - - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); - - // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(currentStateId, currentActionId, newQ); - - if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); - } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); - } - } - } - - - - } + private double lambda = 0.9; + private Matrix e; + private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + + public EligibilityTraceUpdateMode getTraceUpdateMode() { + return traceUpdateMode; + } + + public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + this.traceUpdateMode = traceUpdateMode; + } + + public double getLambda() { + return lambda; + } + + public void setLambda(double lambda) { + this.lambda = lambda; 
+ } + + public QLambdaLearner makeCopy() { + QLambdaLearner clone = new QLambdaLearner(); + clone.copy(this); + return clone; + } + + @Override + public void copy(QLearner rhs) { + super.copy(rhs); + + QLambdaLearner rhs2 = (QLambdaLearner) rhs; + lambda = rhs2.lambda; + e = rhs2.e.makeCopy(); + traceUpdateMode = rhs2.traceUpdateMode; + } + + public QLambdaLearner(QLearner learner) { + copy(learner); + e = new Matrix(model.getStateCount(), model.getActionCount()); + } + + @Override + public boolean equals(Object obj) { + if (!super.equals(obj)) { + return false; + } + + if (obj instanceof QLambdaLearner) { + QLambdaLearner rhs = (QLambdaLearner) obj; + return rhs.lambda == lambda && e + .equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + } + + return false; + } + + public QLambdaLearner() { + super(); + } + + public QLambdaLearner(int stateCount, int actionCount) { + super(stateCount, actionCount); + e = new Matrix(stateCount, actionCount); + } + + public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + e = new Matrix(stateCount, actionCount); + } + + public Matrix getEligibility() { + return e; + } + + public void setEligibility(Matrix e) { + this.e = e; + } + + @Override + public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) { + // old_value is $Q_t(s_t, a_t)$ + double oldQ = model.getQ(currentStateId, currentActionId); + + // learning_rate; + double alpha = model.getAlpha(currentStateId, currentActionId); + + // discount_rate; + double gamma = model.getGamma(); + + // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ + double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); + + double td_error = immediateReward + gamma * maxQ - oldQ; + + int stateCount = model.getStateCount(); + int actionCount = model.getActionCount(); + + e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = model.getQ(stateId, actionId); + double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + + // new_value is $Q_{t+1}(s_t, a_t)$ + model.setQ(currentStateId, currentActionId, newQ); + + if (actionId != currentActionId) { + e.set(currentStateId, actionId, 0); + } else { + e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + } + } + } + + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java index 865abc5..a087957 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java @@ -1,142 +1,137 @@ package com.github.chen0040.rl.learning.qlearn; - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.annotation.JSONField; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.annotation.JSONField; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; 
import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; +import com.google.gson.Gson; import java.io.Serializable; import java.util.Random; import java.util.Set; - /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is + * an off-policy TD control algorithm Q is known as the quality of state-action combination, note + * that it is different from utility of a state */ -public class QLearner implements Serializable,Cloneable { - protected QModel model; - - private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - - public QLearner makeCopy(){ - QLearner clone = new QLearner(); - clone.copy(this); - return clone; - } - - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } - - public static QLearner fromJson(String json){ - return JSON.parseObject(json, QLearner.class); - } - - public void copy(QLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); - } - - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof QLearner){ - QLearner rhs = (QLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; - } - - public QModel getModel() { - return model; - } - - public void setModel(QModel model) { - this.model = model; - } - - - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); - } - - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); - } - - public QLearner(){ - - } - - public QLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.7, 0.1); - } - - public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){ - this.model = model; - this.actionSelectionStrategy = actionSelectionStrategy; - } - - public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - } - - - protected double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; - } - - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); - } - - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); - } - - - public void update(int stateId, int actionId, int nextStateId, double immediateReward){ - update(stateId, actionId, nextStateId, null, immediateReward); - } - - public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { - // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); - - // learning_rate; - double alpha = model.getAlpha(stateId, actionId); - - // discount_rate; - double gamma = 
model.getGamma(); - - // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); - - // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value - // old_value = oldQ - // temporal_difference = learned_value - old_value - // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); - - // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); - } - - +public class QLearner implements Serializable, Cloneable { + protected QModel model; + + private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + + public QLearner makeCopy() { + QLearner clone = new QLearner(); + clone.copy(this); + return clone; + } + + public String toJson() { + return new Gson().toJson(this); +// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public static QLearner fromJson(String json) { + return new Gson().fromJson(json, QLearner.class); +// return JSON.parseObject(json, QLearner.class); + } + + public void copy(QLearner rhs) { + model = rhs.model.makeCopy(); + actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy) + .clone(); + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof QLearner) { + QLearner rhs = (QLearner) obj; + if (!model.equals(rhs.model)) return false; + return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; + } + + public QModel getModel() { + return model; + } + + public void setModel(QModel model) { + this.model = model; + } + + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + } + + public void setActionSelection(String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } + + public QLearner() { + + } + + public QLearner(int stateCount, int actionCount) { + this(stateCount, actionCount, 0.1, 0.7, 0.1); + } + + public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy) { + this.model = model; + this.actionSelectionStrategy = actionSelectionStrategy; + } + + public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + model = new QModel(stateCount, actionCount, initialQ); + model.setAlpha(alpha); + model.setGamma(gamma); + actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + } + + protected double maxQAtState(int stateId, Set actionsAtState) { + IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); + double maxQ = iv.getValue(); + return maxQ; + } + + public IndexValue selectAction(int stateId, Set actionsAtState) { + return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + } + + public IndexValue selectAction(int stateId) { + return selectAction(stateId, null); + } + + public void update(int stateId, int actionId, int nextStateId, double immediateReward) { + update(stateId, actionId, nextStateId, null, immediateReward); + } + + public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) { + // old_value is $Q_t(s_t, a_t)$ + double oldQ = model.getQ(stateId, actionId); + + // learning_rate; + double alpha = model.getAlpha(stateId, actionId); + + // discount_rate; + double gamma = model.getGamma(); + + // estimate_of_optimal_future_value is $max_a 
Q_t(s_{t+1}, a)$ + double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); + + // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value + // old_value = oldQ + // temporal_difference = learned_value - old_value + // new_value = old_value + learning_rate * temporal_difference + double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); + + // new_value is $Q_{t+1}(s_t, a_t)$ + model.setQ(stateId, actionId, newQ); + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java index f26f20a..533273c 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java @@ -6,96 +6,93 @@ import java.util.Random; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. */ -public class RAgent implements Serializable{ - private RLearner learner; - private int currentState; - private int currentAction; - private double currentValue; - - public int getCurrentState(){ - return currentState; - } - - public int getCurrentAction(){ - return currentAction; - } - - public void start(int currentState){ - this.currentState = currentState; - } - - public RAgent makeCopy(){ - RAgent clone = new RAgent(); - clone.copy(this); - return clone; - } - - public void copy(RAgent rhs){ - currentState = rhs.currentState; - currentAction = rhs.currentAction; - learner.copy(rhs.learner); - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RAgent){ - RAgent rhs = (RAgent)obj; - if(!learner.equals(rhs.learner)) return false; - if(currentAction != rhs.currentAction) return false; - return currentState == rhs.currentState; - } - return false; - } - - public IndexValue selectAction(){ - return selectAction(null); - } - - public IndexValue selectAction(Set actionsAtState){ - - if(currentAction==-1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); - } - return new IndexValue(currentAction, currentValue); - } - - public void update(int newState, double immediateReward){ - update(newState, null, immediateReward); - } - - public void update(int newState, Set actionsAtState, double immediateReward){ - if(currentAction != -1) { - learner.update(currentState, currentAction, newState, actionsAtState, immediateReward); - currentState = newState; - currentAction = -1; - } - } - - public RAgent(){ - - } - - - - public RLearner getLearner(){ - return learner; - } - - public void setLearner(RLearner learner){ - this.learner = learner; - } - - public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ){ - learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); - } - - public RAgent(int stateCount, int actionCount){ - learner = new RLearner(stateCount, actionCount); - } +public class RAgent implements Serializable { + private RLearner learner; + private int currentState; + private int currentAction; + private double currentValue; + + public int getCurrentState() { + return currentState; + } + + public int getCurrentAction() { + return currentAction; + } + + public void start(int currentState) { + this.currentState = currentState; + } + + public RAgent makeCopy() { + RAgent clone = new RAgent(); + clone.copy(this); + return clone; + } + + public void copy(RAgent rhs) { + currentState = rhs.currentState; + currentAction = rhs.currentAction; + 
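
As a quick sanity check on the Q-learning backup above, the following minimal sketch (the wrapper class name and the concrete state, action and reward values are purely illustrative) runs a single update with the library's own defaults of alpha = 0.1, gamma = 0.7 and initialQ = 0.1, so the result can be worked out by hand:

import com.github.chen0040.rl.learning.qlearn.QLearner;

public class QLearnerUpdateSketch {
    public static void main(String[] args) {
        // 5 states, 3 actions, alpha = 0.1, gamma = 0.7, every Q initialised to 0.1
        QLearner learner = new QLearner(5, 3, 0.1, 0.7, 0.1);

        // one backup: state 0 --(action 1)--> state 2 with immediate reward 1.0
        learner.update(0, 1, 2, 1.0);

        // newQ = oldQ + alpha * (reward + gamma * maxQ(nextState) - oldQ)
        //      = 0.1  + 0.1   * (1.0 + 0.7 * 0.1 - 0.1) = 0.197
        System.out.println(learner.getModel().getQ(0, 1));   // prints roughly 0.197
    }
}
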
learner.copy(rhs.learner); + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof RAgent) { + RAgent rhs = (RAgent) obj; + if (!learner.equals(rhs.learner)) return false; + if (currentAction != rhs.currentAction) return false; + return currentState == rhs.currentState; + } + return false; + } + + public IndexValue selectAction() { + return selectAction(null); + } + + public IndexValue selectAction(Set actionsAtState) { + + if (currentAction == -1) { + IndexValue iv = learner.selectAction(currentState, actionsAtState); + currentAction = iv.getIndex(); + currentValue = iv.getValue(); + } + return new IndexValue(currentAction, currentValue); + } + + public void update(int newState, double immediateReward) { + update(newState, null, immediateReward); + } + + public void update(int newState, Set actionsAtState, double immediateReward) { + if (currentAction != -1) { + learner.update(currentState, currentAction, newState, actionsAtState, immediateReward); + currentState = newState; + currentAction = -1; + } + } + + public RAgent() { + + } + + public RLearner getLearner() { + return learner; + } + + public void setLearner(RLearner learner) { + this.learner = learner; + } + + public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ) { + learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); + } + + public RAgent(int stateCount, int actionCount) { + learner = new RLearner(stateCount, actionCount); + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java index 910d53f..40dea5a 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java @@ -1,141 +1,141 @@ package com.github.chen0040.rl.learning.rlearn; - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; +import com.google.gson.Gson; + import lombok.Getter; import java.io.Serializable; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. 
*/ -public class RLearner implements Serializable, Cloneable{ - - private QModel model; - private ActionSelectionStrategy actionSelectionStrategy; - private double rho; - private double beta; - - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } +public class RLearner implements Serializable, Cloneable { + + private QModel model; + private ActionSelectionStrategy actionSelectionStrategy; + private double rho; + private double beta; - public static RLearner fromJson(String json){ - return JSON.parseObject(json, RLearner.class); - } + public String toJson() { + return new Gson().toJson(this); +// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public static RLearner fromJson(String json) { + return new Gson().fromJson(json, RLearner.class); +// return JSON.parseObject(json, RLearner.class); + } - public RLearner makeCopy(){ - RLearner clone = new RLearner(); - clone.copy(this); - return clone; - } - - public void copy(RLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone(); - rho = rhs.rho; - beta = rhs.beta; - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RLearner){ - RLearner rhs = (RLearner)obj; - if(!model.equals(rhs.model)) return false; - if(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false; - if(rho != rhs.rho) return false; - return beta == rhs.beta; - } - return false; - } - - public RLearner(){ - - } - - public double getRho() { - return rho; - } - - public void setRho(double rho) { - this.rho = rho; - } - - public double getBeta() { - return beta; - } - - public void setBeta(double beta) { - this.beta = beta; - } - - public QModel getModel(){ - return model; - - } - - public void setModel(QModel model){ - this.model = model; - } - - public String getActionSelection(){ - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); - } - - public void setActionSelection(String conf){ - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); - } - - public RLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); - } - - public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q) - { - model = new QModel(state_count, action_count, initial_Q); - model.setAlpha(alpha); - - this.rho = rho; - this.beta = beta; - - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - } - - private double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; - } - - public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward) - { - double oldQ = model.getQ(currentState, actionTaken); + public RLearner makeCopy() { + RLearner clone = new RLearner(); + clone.copy(this); + return clone; + } + + public void copy(RLearner rhs) { + model = rhs.model.makeCopy(); + actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy) + .clone(); + rho = rhs.rho; + beta = rhs.beta; + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof RLearner) { + RLearner rhs = (RLearner) obj; + if (!model.equals(rhs.model)) return false; + if 
(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false; + if (rho != rhs.rho) return false; + return beta == rhs.beta; + } + return false; + } + + public RLearner() { + + } + + public double getRho() { + return rho; + } + + public void setRho(double rho) { + this.rho = rho; + } + + public double getBeta() { + return beta; + } + + public void setBeta(double beta) { + this.beta = beta; + } + + public QModel getModel() { + return model; + + } + + public void setModel(QModel model) { + this.model = model; + } + + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + } + + public void setActionSelection(String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } + + public RLearner(int stateCount, int actionCount) { + this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); + } + + public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q) { + model = new QModel(state_count, action_count, initial_Q); + model.setAlpha(alpha); + + this.rho = rho; + this.beta = beta; + + actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + } + + private double maxQAtState(int stateId, Set actionsAtState) { + IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); + double maxQ = iv.getValue(); + return maxQ; + } + + public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward) { + double oldQ = model.getQ(currentState, actionTaken); + + double alpha = model.getAlpha(currentState, actionTaken); // learning rate; + + double maxQ = maxQAtState(newState, actionsAtNextStateId); - double alpha = model.getAlpha(currentState, actionTaken); // learning rate; - - double maxQ = maxQAtState(newState, actionsAtNextStateId); + double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ); + + double maxQAtCurrentState = maxQAtState(currentState, null); + if (newQ == maxQAtCurrentState) { + rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState); + } - double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ); - - double maxQAtCurrentState = maxQAtState(currentState, null); - if (newQ == maxQAtCurrentState) - { - rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState); - } + model.setQ(currentState, actionTaken, newQ); + } - model.setQ(currentState, actionTaken, newQ); - } - - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); - } + public IndexValue selectAction(int stateId, Set actionsAtState) { + return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java index c4c8f27..2cfa8ae 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java @@ -6,125 +6,122 @@ import java.util.Random; import java.util.Set; - /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Sarsa, which is an on-policy TD control algorithm + * Created by xschen on 9/27/2015 0027. 
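
To show how the R-learning classes above are meant to be driven, here is a small usage sketch; the environment (random next state, random reward) is only a placeholder and the wrapper class name is illustrative:

import com.github.chen0040.rl.learning.rlearn.RAgent;
import com.github.chen0040.rl.utils.IndexValue;

import java.util.Random;

public class RAgentSketch {
    public static void main(String[] args) {
        int stateCount = 10;
        int actionCount = 4;
        RAgent agent = new RAgent(stateCount, actionCount);

        Random random = new Random();
        agent.start(random.nextInt(stateCount));

        for (int time = 0; time < 100; ++time) {
            IndexValue action = agent.selectAction();      // chosen action id via action.getIndex()
            int newState = random.nextInt(stateCount);     // placeholder transition
            double reward = random.nextDouble();           // placeholder reward
            agent.update(newState, reward);                // R-learning backup; rho is adjusted internally
        }
        System.out.println("estimated average reward rho = " + agent.getLearner().getRho());
    }
}
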
Implement temporal-difference learning Sarsa, which is an + * on-policy TD control algorithm */ -public class SarsaAgent implements Serializable{ - private SarsaLearner learner; - private int currentState; - private int currentAction; - private double currentValue; - private int prevState; - private int prevAction; - - public int getCurrentState(){ - return currentState; - } - - public int getCurrentAction(){ - return currentAction; - } - - public int getPrevState() { return prevState; } - - public int getPrevAction() { return prevAction; } - - public void start(int currentState){ - this.currentState = currentState; - this.prevState = -1; - this.prevAction = -1; - } - - public IndexValue selectAction(){ - return selectAction(null); - } - - public IndexValue selectAction(Set actionsAtState){ - if(currentAction == -1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); - } - - return new IndexValue(currentAction, currentValue); - } - - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); - } - - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ - - IndexValue iv = learner.selectAction(currentState, actionsAtNewState); - int futureAction = iv.getIndex(); - - learner.update(currentState, actionTaken, newState, futureAction, immediateReward); - - prevState = this.currentState; - this.prevAction = actionTaken; - - currentAction = futureAction; - currentState = newState; - } - - - - public SarsaLearner getLearner(){ - return learner; - } - - public void setLearner(SarsaLearner learner){ - this.learner = learner; - } - - public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); - } - - public SarsaAgent(int stateCount, int actionCount){ - learner = new SarsaLearner(stateCount, actionCount); - } - - public SarsaAgent(SarsaLearner learner){ - this.learner = learner; - } - - public SarsaAgent(){ - - } - - public void enableEligibilityTrace(double lambda){ - SarsaLambdaLearner acll = new SarsaLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; - } - - public SarsaAgent makeCopy(){ - SarsaAgent clone = new SarsaAgent(); - clone.copy(this); - return clone; - } - - public void copy(SarsaAgent rhs){ - learner.copy(rhs.learner); - currentAction = rhs.currentAction; - currentState = rhs.currentState; - prevAction = rhs.prevAction; - prevState = rhs.prevState; - } - - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof SarsaAgent){ - SarsaAgent rhs = (SarsaAgent)obj; - return prevAction == rhs.prevAction - && prevState == rhs.prevState - && currentAction == rhs.currentAction - && currentState == rhs.currentState - && learner.equals(rhs.learner); - } - return false; - } +public class SarsaAgent implements Serializable { + private SarsaLearner learner; + private int currentState; + private int currentAction; + private double currentValue; + private int prevState; + private int prevAction; + + public int getCurrentState() { + return currentState; + } + + public int getCurrentAction() { + return currentAction; + } + + public int getPrevState() { return prevState; } + + public int getPrevAction() { return prevAction; } + + public void start(int currentState) { + this.currentState = currentState; + this.prevState = -1; + this.prevAction = -1; + } 
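
For context, a minimal sketch of the agent-level API defined by this class, assuming a stand-in environment that returns random next states and rewards (the class name and constants are illustrative only):

import com.github.chen0040.rl.learning.sarsa.SarsaAgent;

import java.util.Random;

public class SarsaAgentSketch {
    public static void main(String[] args) {
        int stateCount = 20;
        int actionCount = 5;
        SarsaAgent agent = new SarsaAgent(stateCount, actionCount);

        Random random = new Random();
        agent.start(random.nextInt(stateCount));

        for (int time = 0; time < 200; ++time) {
            int actionTaken = agent.selectAction().getIndex();
            int newState = random.nextInt(stateCount);      // placeholder environment response
            double reward = random.nextDouble();            // placeholder reward
            // the agent picks the follow-up action itself inside update(),
            // which is what makes the SARSA backup on-policy
            agent.update(actionTaken, newState, reward);
        }
        System.out.println("Q(currentState, currentAction) = "
                + agent.getLearner().getModel().getQ(agent.getCurrentState(), agent.getCurrentAction()));
    }
}
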
+ + public IndexValue selectAction() { + return selectAction(null); + } + + public IndexValue selectAction(Set actionsAtState) { + if (currentAction == -1) { + IndexValue iv = learner.selectAction(currentState, actionsAtState); + currentAction = iv.getIndex(); + currentValue = iv.getValue(); + } + + return new IndexValue(currentAction, currentValue); + } + + public void update(int actionTaken, int newState, double immediateReward) { + update(actionTaken, newState, null, immediateReward); + } + + public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward) { + + IndexValue iv = learner.selectAction(currentState, actionsAtNewState); + int futureAction = iv.getIndex(); + + learner.update(currentState, actionTaken, newState, futureAction, immediateReward); + + prevState = this.currentState; + this.prevAction = actionTaken; + + currentAction = futureAction; + currentState = newState; + } + + public SarsaLearner getLearner() { + return learner; + } + + public void setLearner(SarsaLearner learner) { + this.learner = learner; + } + + public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); + } + + public SarsaAgent(int stateCount, int actionCount) { + learner = new SarsaLearner(stateCount, actionCount); + } + + public SarsaAgent(SarsaLearner learner) { + this.learner = learner; + } + + public SarsaAgent() { + + } + + public void enableEligibilityTrace(double lambda) { + SarsaLambdaLearner acll = new SarsaLambdaLearner(learner); + acll.setLambda(lambda); + learner = acll; + } + + public SarsaAgent makeCopy() { + SarsaAgent clone = new SarsaAgent(); + clone.copy(this); + return clone; + } + + public void copy(SarsaAgent rhs) { + learner.copy(rhs.learner); + currentAction = rhs.currentAction; + currentState = rhs.currentState; + prevAction = rhs.prevAction; + prevState = rhs.prevState; + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof SarsaAgent) { + SarsaAgent rhs = (SarsaAgent) obj; + return prevAction == rhs.prevAction + && prevState == rhs.prevState + && currentAction == rhs.currentAction + && currentState == rhs.currentState + && learner.equals(rhs.learner); + } + return false; + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java index e51543e..0a94fe4 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java @@ -1,130 +1,127 @@ package com.github.chen0040.rl.learning.sarsa; - import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; - /** * Created by xschen on 9/28/2015 0028. 
*/ public class SarsaLambdaLearner extends SarsaLearner { - private double lambda = 0.9; - private Matrix e; - private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; - - public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; - } - - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { - this.traceUpdateMode = traceUpdateMode; - } - - public double getLambda(){ - return lambda; - } - - public void setLambda(double lambda){ - this.lambda = lambda; - } - - @Override - public Object clone(){ - SarsaLambdaLearner clone = new SarsaLambdaLearner(); - clone.copy(this); - return clone; - } - - @Override - public void copy(SarsaLearner rhs){ - super.copy(rhs); - - SarsaLambdaLearner rhs2 = (SarsaLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; - } - - @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ - return false; - } - - if(obj instanceof SarsaLambdaLearner){ - SarsaLambdaLearner rhs = (SarsaLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; - } - - return false; - } - - public SarsaLambdaLearner(){ - super(); - } - - public SarsaLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(SarsaLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); - } - - public Matrix getEligibility() - { - return e; - } - - public void setEligibility(Matrix e){ - this.e = e; - } - - @Override - public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) - { - // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); - - // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); - - // discount_rate; - double gamma = model.getGamma(); - - // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); - - double td_error = immediateReward + gamma * nextQ - oldQ; - - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); - - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); - - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); - - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); - - model.setQ(stateId, actionId, newQ); - - if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); - } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); - } - } - } - } + private double lambda = 0.9; + private Matrix e; + private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + + public EligibilityTraceUpdateMode getTraceUpdateMode() { + return traceUpdateMode; + } + + public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + this.traceUpdateMode = traceUpdateMode; + } + + public double getLambda() { + return lambda; + } + + public void setLambda(double lambda) { + this.lambda = lambda; + } + + 
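
The eligibility-trace behaviour configured by lambda above is normally switched on through SarsaAgent.enableEligibilityTrace, which wraps an existing SarsaLearner in this class; a minimal sketch, where the lambda value and the single toy transition are illustrative:

import com.github.chen0040.rl.learning.sarsa.SarsaAgent;

public class SarsaLambdaSketch {
    public static void main(String[] args) {
        SarsaAgent agent = new SarsaAgent(20, 5);

        // replace the plain SarsaLearner with a SarsaLambdaLearner so that each TD error
        // is propagated backwards along the eligibility trace, decayed by gamma * lambda
        agent.enableEligibilityTrace(0.9);

        agent.start(0);
        int action = agent.selectAction().getIndex();
        agent.update(action, 1, 0.5);   // illustrative transition: state 0 -> state 1, reward 0.5
    }
}
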
@Override + public Object clone() { + SarsaLambdaLearner clone = new SarsaLambdaLearner(); + clone.copy(this); + return clone; + } + + @Override + public void copy(SarsaLearner rhs) { + super.copy(rhs); + + SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs; + lambda = rhs2.lambda; + e = rhs2.e.makeCopy(); + traceUpdateMode = rhs2.traceUpdateMode; + } + + @Override + public boolean equals(Object obj) { + if (!super.equals(obj)) { + return false; + } + + if (obj instanceof SarsaLambdaLearner) { + SarsaLambdaLearner rhs = (SarsaLambdaLearner) obj; + return rhs.lambda == lambda && e + .equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + } + + return false; + } + + public SarsaLambdaLearner() { + super(); + } + + public SarsaLambdaLearner(int stateCount, int actionCount) { + super(stateCount, actionCount); + e = new Matrix(stateCount, actionCount); + } + + public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + e = new Matrix(stateCount, actionCount); + } + + public SarsaLambdaLearner(SarsaLearner learner) { + copy(learner); + e = new Matrix(model.getStateCount(), model.getActionCount()); + } + + public Matrix getEligibility() { + return e; + } + + public void setEligibility(Matrix e) { + this.e = e; + } + + @Override + public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) { + // old_value is $Q_t(s_t, a_t)$ + double oldQ = model.getQ(currentStateId, currentActionId); + + // learning_rate; + double alpha = model.getAlpha(currentStateId, currentActionId); + + // discount_rate; + double gamma = model.getGamma(); + + // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ + double nextQ = model.getQ(nextStateId, nextActionId); + + double td_error = immediateReward + gamma * nextQ - oldQ; + + int stateCount = model.getStateCount(); + int actionCount = model.getActionCount(); + + e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = model.getQ(stateId, actionId); + + double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + + model.setQ(stateId, actionId, newQ); + + if (actionId != currentActionId) { + e.set(currentStateId, actionId, 0); + } else { + e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + } + } + } + } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java index 7fef780..fd04dcb 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java @@ -1,160 +1,157 @@ package com.github.chen0040.rl.learning.sarsa; - -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; +import 
com.google.gson.Gson; import java.io.Serializable; import java.util.Random; import java.util.Set; - /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is + * an off-policy TD control algorithm Q is known as the quality of state-action combination, note + * that it is different from utility of a state */ -public class SarsaLearner implements Serializable,Cloneable { - protected QModel model; - private ActionSelectionStrategy actionSelectionStrategy; - - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } - - public static SarsaLearner fromJson(String json){ - return JSON.parseObject(json, SarsaLearner.class); - } - - public SarsaLearner makeCopy(){ - SarsaLearner clone = new SarsaLearner(); - clone.copy(this); - return clone; - } +public class SarsaLearner implements Serializable, Cloneable { + protected QModel model; + private ActionSelectionStrategy actionSelectionStrategy; - public void copy(SarsaLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); - } + public String toJson() { + return new Gson().toJson(this); +// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof SarsaLearner){ - SarsaLearner rhs = (SarsaLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; - } + public static SarsaLearner fromJson(String json) { + return new Gson().fromJson(json, SarsaLearner.class); +// return JSON.parseObject(json, SarsaLearner.class); + } - public QModel getModel() { - return model; - } + public SarsaLearner makeCopy() { + SarsaLearner clone = new SarsaLearner(); + clone.copy(this); + return clone; + } - public void setModel(QModel model) { - this.model = model; - } + public void copy(SarsaLearner rhs) { + model = rhs.model.makeCopy(); + actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy) + .clone(); + } - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); - } + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof SarsaLearner) { + SarsaLearner rhs = (SarsaLearner) obj; + if (!model.equals(rhs.model)) return false; + return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; + } - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); - } + public QModel getModel() { + return model; + } - public SarsaLearner(){ + public void setModel(QModel model) { + this.model = model; + } - } + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + } - public SarsaLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.7, 0.1); - } + public void setActionSelection(String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } - public SarsaLearner(QModel model, 
ActionSelectionStrategy actionSelectionStrategy){ - this.model = model; - this.actionSelectionStrategy = actionSelectionStrategy; - } + public SarsaLearner() { - public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - } + } - public static void main(String[] args){ - int stateCount = 100; - int actionCount = 10; + public SarsaLearner(int stateCount, int actionCount) { + this(stateCount, actionCount, 0.1, 0.7, 0.1); + } - SarsaLearner learner = new SarsaLearner(stateCount, actionCount); + public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy) { + this.model = model; + this.actionSelectionStrategy = actionSelectionStrategy; + } - double reward = 0; // reward gained by transiting from prevState to currentState - Random random = new Random(); - int currentStateId = random.nextInt(stateCount); - int currentActionId = learner.selectAction(currentStateId).getIndex(); + public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) { + model = new QModel(stateCount, actionCount, initialQ); + model.setAlpha(alpha); + model.setGamma(gamma); + actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + } - for(int time=0; time < 1000; ++time){ + public static void main(String[] args) { + int stateCount = 100; + int actionCount = 10; - System.out.println("Controller does action-"+currentActionId); + SarsaLearner learner = new SarsaLearner(stateCount, actionCount); - int newStateId = random.nextInt(actionCount); - reward = random.nextDouble(); + double reward = 0; // reward gained by transiting from prevState to currentState + Random random = new Random(); + int currentStateId = random.nextInt(stateCount); + int currentActionId = learner.selectAction(currentStateId).getIndex(); - System.out.println("Now the new state is " + newStateId); - System.out.println("Controller receives Reward = " + reward); + for (int time = 0; time < 1000; ++time) { - int futureActionId = learner.selectAction(newStateId).getIndex(); + System.out.println("Controller does action-" + currentActionId); - System.out.println("Controller is expected to do action-"+futureActionId); + int newStateId = random.nextInt(actionCount); + reward = random.nextDouble(); - learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Controller receives Reward = " + reward); - currentStateId = newStateId; - currentActionId = futureActionId; - } - } + int futureActionId = learner.selectAction(newStateId).getIndex(); + System.out.println("Controller is expected to do action-" + futureActionId); - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); - } + learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward); - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); - } + currentStateId = newStateId; + currentActionId = futureActionId; + } + } - public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward) - { - // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); + public IndexValue selectAction(int stateId, Set actionsAtState) { + 
return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + } - // learning_rate; - double alpha = model.getAlpha(stateId, actionId); + public IndexValue selectAction(int stateId) { + return selectAction(stateId, null); + } - // discount_rate; - double gamma = model.getGamma(); + public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward) { + // old_value is $Q_t(s_t, a_t)$ + double oldQ = model.getQ(stateId, actionId); - // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); + // learning_rate; + double alpha = model.getAlpha(stateId, actionId); - // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value - // old_value = oldQ - // temporal_difference = learned_value - old_value - // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); + // discount_rate; + double gamma = model.getGamma(); - // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); - } + // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ + double nextQ = model.getQ(nextStateId, nextActionId); + // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value + // old_value = oldQ + // temporal_difference = learned_value - old_value + // new_value = old_value + learning_rate * temporal_difference + double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); + // new_value is $Q_{t+1}(s_t, a_t)$ + model.setQ(stateId, actionId, newQ); + } } diff --git a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java index e25380f..dd6dc71 100644 --- a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java +++ b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java @@ -4,6 +4,6 @@ * Created by xschen on 9/28/2015 0028. */ public enum EligibilityTraceUpdateMode { - ReplaceTrace, - AccumulateTrace + ReplaceTrace, + AccumulateTrace } diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index 2d314a1..22cbde9 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -1,158 +1,166 @@ package com.github.chen0040.rl.models; - import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; + import lombok.Getter; import lombok.Setter; import java.util.*; - /** - * @author xschen - * 9/27/2015 0027. - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * @author xschen 9/27/2015 0027. 
Q is known as the quality of state-action combination, note that + * it is different from utility of a state */ @Getter @Setter public class QModel { - /** - * Q value for (state_id, action_id) pair - * Q is known as the quality of state-action combination, note that it is different from utility of a state - */ - private Matrix Q; - /** - * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) - */ - private Matrix alphaMatrix; - - /** - * discount factor - */ - private double gamma = 0.7; - - private int stateCount; - private int actionCount; - - public QModel(int stateCount, int actionCount, double initialQ){ - this.stateCount = stateCount; - this.actionCount = actionCount; - Q = new Matrix(stateCount,actionCount); - alphaMatrix = new Matrix(stateCount, actionCount); - Q.setAll(initialQ); - alphaMatrix.setAll(0.1); - } - - public QModel(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1); - } - - public QModel(){ - - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof QModel){ - QModel rhs2 = (QModel)rhs; - - - if(gamma != rhs2.gamma) return false; - - - if(stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false; - - if((Q!=null && rhs2.Q==null) || (Q==null && rhs2.Q !=null)) return false; - if((alphaMatrix !=null && rhs2.alphaMatrix ==null) || (alphaMatrix ==null && rhs2.alphaMatrix !=null)) return false; - - return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix))); - - } - return false; - } - - public QModel makeCopy(){ - QModel clone = new QModel(); - clone.copy(this); - return clone; - } - - public void copy(QModel rhs){ - gamma = rhs.gamma; - stateCount = rhs.stateCount; - actionCount = rhs.actionCount; - Q = rhs.Q==null ? null : rhs.Q.makeCopy(); - alphaMatrix = rhs.alphaMatrix == null ? 
null : rhs.alphaMatrix.makeCopy(); - } - - - public double getQ(int stateId, int actionId){ - return Q.get(stateId, actionId); - } - - - public void setQ(int stateId, int actionId, double Qij){ - Q.set(stateId, actionId, Qij); - } - - - public double getAlpha(int stateId, int actionId){ - return alphaMatrix.get(stateId, actionId); - } - - - public void setAlpha(double defaultAlpha) { - this.alphaMatrix.setAll(defaultAlpha); - } - - - public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState){ - Vec rowVector = Q.rowAt(stateId); - return rowVector.indexWithMaxValue(actionsAtState); - } - - private void reset(double initialQ){ - Q.setAll(initialQ); - } - - - public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtState, Random random) { - Vec rowVector = Q.rowAt(stateId); - double sum = 0; - - if(actionsAtState==null){ - actionsAtState = new HashSet<>(); - for(int i=0; i < actionCount; ++i){ - actionsAtState.add(i); - } - } - - List actions = new ArrayList<>(); - for(Integer actionId : actionsAtState){ - actions.add(actionId); - } - - double[] acc = new double[actions.size()]; - for(int i=0; i < actions.size(); ++i){ - sum += rowVector.get(actions.get(i)); - acc[i] = sum; - } - - - double r = random.nextDouble() * sum; - - IndexValue result = new IndexValue(); - for(int i=0; i < actions.size(); ++i){ - if(acc[i] >= r){ - int actionId = actions.get(i); - result.setIndex(actionId); - result.setValue(rowVector.get(actionId)); - break; - } - } - - return result; - } + /** + * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, + * note that it is different from utility of a state + */ + private Matrix Q; + /** + * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) + */ + private Matrix alphaMatrix; + + /** + * discount factor + */ + private double gamma = 0.7; + + private int stateCount; + private int actionCount; + + public QModel(int stateCount, int actionCount, double initialQ) { + this.stateCount = stateCount; + this.actionCount = actionCount; + Q = new Matrix(stateCount, actionCount); + alphaMatrix = new Matrix(stateCount, actionCount); + Q.setAll(initialQ); + alphaMatrix.setAll(0.1); + } + + public QModel(int stateCount, int actionCount) { + this(stateCount, actionCount, 0.1); + } + + public QModel() { + + } + + @Override + public boolean equals(Object rhs) { + if (rhs != null && rhs instanceof QModel) { + QModel rhs2 = (QModel) rhs; + + if (gamma != rhs2.gamma) return false; + + if (stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false; + + if ((Q != null && rhs2.Q == null) || (Q == null && rhs2.Q != null)) return false; + if ((alphaMatrix != null && rhs2.alphaMatrix == null) || (alphaMatrix == null && rhs2.alphaMatrix != null)) + return false; + + return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix + .equals(rhs2.alphaMatrix))); + + } + return false; + } + + public QModel makeCopy() { + QModel clone = new QModel(); + clone.copy(this); + return clone; + } + + public void copy(QModel rhs) { + gamma = rhs.gamma; + stateCount = rhs.stateCount; + actionCount = rhs.actionCount; + Q = rhs.Q == null ? null : rhs.Q.makeCopy(); + alphaMatrix = rhs.alphaMatrix == null ? 
null : rhs.alphaMatrix.makeCopy(); + } + + public double getQ(int stateId, int actionId) { + return Q.get(stateId, actionId); + } + + public void setQ(int stateId, int actionId, double Qij) { + Q.set(stateId, actionId, Qij); + } + + public double getAlpha(int stateId, int actionId) { + return alphaMatrix.get(stateId, actionId); + } + + public void setAlpha(double defaultAlpha) { + this.alphaMatrix.setAll(defaultAlpha); + } + + public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState) { + Vec rowVector = Q.rowAt(stateId); + return rowVector.indexWithMaxValue(actionsAtState); + } + + private void reset(double initialQ) { + Q.setAll(initialQ); + } + + public IndexValue actionWithSoftMaxQAtState(int stateId, Set actionsAtState, Random random) { + Vec rowVector = Q.rowAt(stateId); + double sum = 0; + + if (actionsAtState == null) { + actionsAtState = new HashSet<>(); + for (int i = 0; i < actionCount; ++i) { + actionsAtState.add(i); + } + } + + List actions = new ArrayList<>(actionsAtState); + + double[] acc = new double[actions.size()]; + for (int i = 0; i < actions.size(); ++i) { + sum += rowVector.get(actions.get(i)); + acc[i] = sum; + } + + double r = random.nextDouble() * sum; + + IndexValue result = new IndexValue(); + for (int i = 0; i < actions.size(); ++i) { + if (acc[i] >= r) { + int actionId = actions.get(i); + result.setIndex(actionId); + result.setValue(rowVector.get(actionId)); + break; + } + } + + return result; + } + + public int getActionCount() { + return this.actionCount; + } + + public int getStateCount() { + return stateCount; + } + + public double getGamma() { + return this.gamma; + } + + public void setGamma(double gamma) { + this.gamma = gamma; + } + + public Matrix getAlphaMatrix() { + return this.alphaMatrix; + } } diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java index cff1859..8188236 100644 --- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java +++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java @@ -1,91 +1,91 @@ package com.github.chen0040.rl.models; import com.github.chen0040.rl.utils.Vec; + import lombok.Getter; import lombok.Setter; import java.io.Serializable; - /** - * @author xschen - * 9/27/2015 0027. - * Utility value of a state $U(s)$ is the expected long term reward of state $s$ given the sequence of reward and the optimal policy - * Utility value $U(s)$ at state $s$ can be obtained by the Bellman equation - * Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ - * where s' is the possible transitioned state given that action $a$ is applied at state $s$ - * where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ - * where $\sum_{s'} T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ - * where $max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ is applied at state $s$ + * @author xschen 9/27/2015 0027. 
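
Despite its name, actionWithSoftMaxQAtState above draws an action with probability proportional to its raw Q value (a roulette-wheel draw over the row of Q) rather than applying an exponential softmax; a small sketch to illustrate, with made-up Q values and a hypothetical wrapper class:

import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;

import java.util.Random;

public class ProportionalSelectionSketch {
    public static void main(String[] args) {
        QModel model = new QModel(3, 3, 0.1);   // 3 states, 3 actions, every Q = 0.1
        model.setQ(0, 2, 0.8);                  // make action 2 dominant at state 0

        Random random = new Random();
        int[] counts = new int[3];
        for (int i = 0; i < 10000; ++i) {
            IndexValue choice = model.actionWithSoftMaxQAtState(0, null, random);
            counts[choice.getIndex()]++;
        }

        // action 2 should be drawn roughly 0.8 / (0.1 + 0.1 + 0.8) = 80% of the time
        System.out.println(counts[0] + " " + counts[1] + " " + counts[2]);
    }
}
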
Utility value of a state $U(s)$ is the expected long term reward + * of state $s$ given the sequence of reward and the optimal policy Utility value $U(s)$ at state + * $s$ can be obtained by the Bellman equation Bellman Equtation states that $U(s) = R(s) + \gamma * + * max_a \sum_{s'} T(s,a,s')U(s')$ where s' is the possible transitioned state given that action $a$ + * is applied at state $s$ where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ + * given that action $a$ is applied at state $s$ where $\sum_{s'} T(s,a,s')U(s')$ is the expected + * long term reward given that action $a$ is applied at state $s$ where $max_a \sum_{s'} + * T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ + * is applied at state $s$ */ @Getter @Setter public class UtilityModel implements Serializable { - private Vec U; - private int stateCount; - private int actionCount; - - public void setU(Vec U){ - this.U = U; - } - - public Vec getU() { - return U; - } - - public double getU(int stateId){ - return U.get(stateId); - } - - public int getStateCount() { - return stateCount; - } - - public int getActionCount() { - return actionCount; - } - - public UtilityModel(int stateCount, int actionCount, double initialU){ - this.stateCount = stateCount; - this.actionCount = actionCount; - U = new Vec(stateCount); - U.setAll(initialU); - } - - public UtilityModel(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1); - } - - public UtilityModel(){ - - } - - public void copy(UtilityModel rhs){ - U = rhs.U==null ? null : rhs.U.makeCopy(); - actionCount = rhs.actionCount; - stateCount = rhs.stateCount; - } - - public UtilityModel makeCopy(){ - UtilityModel clone = new UtilityModel(); - clone.copy(this); - return clone; - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof UtilityModel){ - UtilityModel rhs2 = (UtilityModel)rhs; - if(actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false; - - if((U==null && rhs2.U!=null) && (U!=null && rhs2.U ==null)) return false; - return !(U != null && !U.equals(rhs2.U)); - - } - return false; - } - - public void reset(double initialU){ - U.setAll(initialU); - } + private Vec U; + private int stateCount; + private int actionCount; + + public void setU(Vec U) { + this.U = U; + } + + public Vec getU() { + return U; + } + + public double getU(int stateId) { + return U.get(stateId); + } + + public int getStateCount() { + return stateCount; + } + + public int getActionCount() { + return actionCount; + } + + public UtilityModel(int stateCount, int actionCount, double initialU) { + this.stateCount = stateCount; + this.actionCount = actionCount; + U = new Vec(stateCount); + U.setAll(initialU); + } + + public UtilityModel(int stateCount, int actionCount) { + this(stateCount, actionCount, 0.1); + } + + public UtilityModel() { + + } + + public void copy(UtilityModel rhs) { + U = rhs.U == null ? 
null : rhs.U.makeCopy(); + actionCount = rhs.actionCount; + stateCount = rhs.stateCount; + } + + public UtilityModel makeCopy() { + UtilityModel clone = new UtilityModel(); + clone.copy(this); + return clone; + } + + @Override + public boolean equals(Object rhs) { + if (rhs != null && rhs instanceof UtilityModel) { + UtilityModel rhs2 = (UtilityModel) rhs; + if (actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false; + + if ((U == null && rhs2.U != null) && (U != null && rhs2.U == null)) return false; + return !(U != null && !U.equals(rhs2.U)); + + } + return false; + } + + public void reset(double initialU) { + U.setAll(initialU); + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java index e840bc1..145b1e4 100644 --- a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java @@ -4,11 +4,11 @@ * Created by xschen on 10/11/2015 0011. */ public class DoubleUtils { - public static boolean equals(double a1, double a2){ - return Math.abs(a1-a2) < 1e-10; - } + public static boolean equals(double a1, double a2) { + return Math.abs(a1 - a2) < 1e-10; + } - public static boolean isZero(double a){ - return a < 1e-20; - } + public static boolean isZero(double a) { + return a < 1e-20; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java index 66c2bf6..621e7a3 100644 --- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java +++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java @@ -1,46 +1,59 @@ package com.github.chen0040.rl.utils; - import lombok.Getter; import lombok.Setter; - /** * Created by xschen on 6/5/2017. 
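
The Bellman relation quoted in the UtilityModel javadoc above, U(s) = R(s) + gamma * max_a sum_{s'} T(s,a,s') U(s'), can be sanity-checked with a tiny hand-computed backup; the two-state numbers below are invented purely for illustration:

public class BellmanBackupSketch {
    public static void main(String[] args) {
        double gamma = 0.7;
        double rewardAtS0 = 1.0;            // R(s0)
        double[] utility = {0.0, 2.0};      // current estimates U(s0), U(s1)

        // a single action from s0: go to s1 with probability 0.9, stay at s0 with probability 0.1
        // (with only one action the max over actions is trivial)
        double expected = 0.9 * utility[1] + 0.1 * utility[0];
        double backedUp = rewardAtS0 + gamma * expected;   // 1.0 + 0.7 * 1.8 = 2.26

        System.out.println(backedUp);       // prints approximately 2.26
    }
}
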
*/ @Getter @Setter public class IndexValue { - private int index; - private double value; - - public IndexValue(){ - - } - - public IndexValue(int index, double value){ - this.index = index; - this.value = value; - } - - public IndexValue makeCopy(){ - IndexValue clone = new IndexValue(); - clone.setValue(value); - clone.setIndex(index); - return clone; - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof IndexValue){ - IndexValue rhs2 = (IndexValue)rhs; - return index == rhs2.index && value == rhs2.value; - } - return false; - } - - public boolean isValid(){ - return index != -1; - } - + private int index; + private double value; + + public IndexValue() { + + } + + public IndexValue(int index, double value) { + this.index = index; + this.value = value; + } + + public IndexValue makeCopy() { + IndexValue clone = new IndexValue(); + clone.setValue(value); + clone.setIndex(index); + return clone; + } + + public void setIndex(int index) { + this.index = index; + } + + @Override + public boolean equals(Object rhs) { + if (rhs != null && rhs instanceof IndexValue) { + IndexValue rhs2 = (IndexValue) rhs; + return index == rhs2.index && value == rhs2.value; + } + return false; + } + + public boolean isValid() { + return index != -1; + } + + public void setValue(double v) { + this.value = v; + } + + public int getIndex() { + return this.index; + } + + public double getValue() { + return this.value; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java index cd42bd5..2421a8b 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java +++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.utils; -import com.alibaba.fastjson.annotation.JSONField; +//import com.alibaba.fastjson.annotation.JSONField; + import lombok.Getter; import lombok.Setter; @@ -10,234 +11,228 @@ import java.util.List; import java.util.Map; - /** * Created by xschen on 9/27/2015 0027. */ @Getter @Setter public class Matrix implements Serializable { - private Map rows = new HashMap<>(); - private int rowCount; - private int columnCount; - private double defaultValue; - - public Matrix(){ - - } - - public Matrix(double[][] A){ - for(int i = 0; i < A.length; ++i){ - double[] B = A[i]; - for(int j=0; j < B.length; ++j){ - set(i, j, B[j]); - } - } - } - - public void setRow(int rowIndex, Vec rowVector){ - rowVector.setId(rowIndex); - rows.put(rowIndex, rowVector); - } - - - public static Matrix identity(int dimension){ - Matrix m = new Matrix(dimension, dimension); - for(int i=0; i < m.getRowCount(); ++i){ - m.set(i, i, 1); - } - return m; - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Matrix){ - Matrix rhs2 = (Matrix)rhs; - if(rowCount != rhs2.rowCount || columnCount != rhs2.columnCount){ - return false; - } - - if(defaultValue == rhs2.defaultValue) { - for (Integer index : rows.keySet()) { - if (!rhs2.rows.containsKey(index)) return false; - if (!rows.get(index).equals(rhs2.rows.get(index))) { - System.out.println("failed!"); - return false; - } - } - - for (Integer index : rhs2.rows.keySet()) { - if (!rows.containsKey(index)) return false; - if (!rhs2.rows.get(index).equals(rows.get(index))) { - System.out.println("failed! 
22"); - return false; - } - } - } else { - - for(int i=0; i < rowCount; ++i) { - for(int j=0; j < columnCount; ++j) { - if(this.get(i, j) != rhs2.get(i, j)){ - return false; - } - } - } - } - - return true; - } - - return false; - } - - public Matrix makeCopy(){ - Matrix clone = new Matrix(rowCount, columnCount); - clone.copy(this); - return clone; - } - - public void copy(Matrix rhs){ - rowCount = rhs.rowCount; - columnCount = rhs.columnCount; - defaultValue = rhs.defaultValue; - - rows.clear(); - - for(Map.Entry entry : rhs.rows.entrySet()){ - rows.put(entry.getKey(), entry.getValue().makeCopy()); - } - } - - - - public void set(int rowIndex, int columnIndex, double value){ - Vec row = rowAt(rowIndex); - row.set(columnIndex, value); - if(rowIndex >= rowCount) { rowCount = rowIndex+1; } - if(columnIndex >= columnCount) { columnCount = columnIndex + 1; } - } - - - - public Matrix(int rowCount, int columnCount){ - this.rowCount = rowCount; - this.columnCount = columnCount; - this.defaultValue = 0; - } - - public Vec rowAt(int rowIndex){ - Vec row = rows.get(rowIndex); - if(row == null){ - row = new Vec(columnCount); - row.setAll(defaultValue); - row.setId(rowIndex); - rows.put(rowIndex, row); - } - return row; - } - - public void setAll(double value){ - defaultValue = value; - for(Vec row : rows.values()){ - row.setAll(value); - } - } - - public double get(int rowIndex, int columnIndex) { - Vec row= rowAt(rowIndex); - return row.get(columnIndex); - } - - public List columnVectors() - { - Matrix A = this; - int n = A.getColumnCount(); - int rowCount = A.getRowCount(); - - List Acols = new ArrayList(); - - for (int c = 0; c < n; ++c) - { - Vec Acol = new Vec(rowCount); - Acol.setAll(defaultValue); - Acol.setId(c); - - for (int r = 0; r < rowCount; ++r) - { - Acol.set(r, A.get(r, c)); - } - Acols.add(Acol); - } - return Acols; - } - - public Matrix multiply(Matrix rhs) - { - if(this.getColumnCount() != rhs.getRowCount()){ - System.err.println("A.columnCount must be equal to B.rowCount in multiplication"); - return null; - } - - Vec row1; - Vec col2; - - Matrix result = new Matrix(getRowCount(), rhs.getColumnCount()); - result.setAll(defaultValue); - - List rhsColumns = rhs.columnVectors(); - - for (Map.Entry entry : rows.entrySet()) - { - int r1 = entry.getKey(); - row1 = entry.getValue(); - for (int c2 = 0; c2 < rhsColumns.size(); ++c2) - { - col2 = rhsColumns.get(c2); - result.set(r1, c2, row1.multiply(col2)); - } - } - - return result; - } - - @JSONField(serialize = false) - public boolean isSymmetric(){ - if (getRowCount() != getColumnCount()) return false; - - for (Map.Entry rowEntry : rows.entrySet()) - { - int row = rowEntry.getKey(); - Vec rowVec = rowEntry.getValue(); - - for (Integer col : rowVec.getData().keySet()) - { - if (row == col.intValue()) continue; - if(DoubleUtils.equals(rowVec.get(col), this.get(col, row))){ - return false; - } - } - } - - return true; - } - - public Vec multiply(Vec rhs) - { - if(this.getColumnCount() != rhs.getDimension()){ - System.err.println("columnCount must be equal to the size of the vector for multiplication"); - } - - Vec row1; - Vec result = new Vec(getRowCount()); - for (Map.Entry entry : rows.entrySet()) - { - row1 = entry.getValue(); - result.set(entry.getKey(), row1.multiply(rhs)); - } - return result; - } - - - + private Map rows = new HashMap<>(); + private int rowCount; + private int columnCount; + private double defaultValue; + + public Matrix() { + + } + + public Matrix(double[][] A) { + for (int i = 0; i < A.length; ++i) { + double[] 
B = A[i]; + for (int j = 0; j < B.length; ++j) { + set(i, j, B[j]); + } + } + } + + public void setRow(int rowIndex, Vec rowVector) { + rowVector.setId(rowIndex); + rows.put(rowIndex, rowVector); + } + + public static Matrix identity(int dimension) { + Matrix m = new Matrix(dimension, dimension); + for (int i = 0; i < m.getRowCount(); ++i) { + m.set(i, i, 1); + } + return m; + } + + public int getRowCount() { + return this.rowCount; + } + + @Override + public boolean equals(Object rhs) { + if (rhs != null && rhs instanceof Matrix) { + Matrix rhs2 = (Matrix) rhs; + if (rowCount != rhs2.rowCount || columnCount != rhs2.columnCount) { + return false; + } + + if (defaultValue == rhs2.defaultValue) { + for (Integer index : rows.keySet()) { + if (!rhs2.rows.containsKey(index)) return false; + if (!rows.get(index).equals(rhs2.rows.get(index))) { + System.out.println("failed!"); + return false; + } + } + + for (Integer index : rhs2.rows.keySet()) { + if (!rows.containsKey(index)) return false; + if (!rhs2.rows.get(index).equals(rows.get(index))) { + System.out.println("failed! 22"); + return false; + } + } + } else { + + for (int i = 0; i < rowCount; ++i) { + for (int j = 0; j < columnCount; ++j) { + if (this.get(i, j) != rhs2.get(i, j)) { + return false; + } + } + } + } + + return true; + } + + return false; + } + + public Matrix makeCopy() { + Matrix clone = new Matrix(rowCount, columnCount); + clone.copy(this); + return clone; + } + + public void copy(Matrix rhs) { + rowCount = rhs.rowCount; + columnCount = rhs.columnCount; + defaultValue = rhs.defaultValue; + + rows.clear(); + + for (Map.Entry entry : rhs.rows.entrySet()) { + rows.put(entry.getKey(), entry.getValue().makeCopy()); + } + } + + public void set(int rowIndex, int columnIndex, double value) { + Vec row = rowAt(rowIndex); + row.set(columnIndex, value); + if (rowIndex >= rowCount) { + rowCount = rowIndex + 1; + } + if (columnIndex >= columnCount) { + columnCount = columnIndex + 1; + } + } + + public Matrix(int rowCount, int columnCount) { + this.rowCount = rowCount; + this.columnCount = columnCount; + this.defaultValue = 0; + } + + public Vec rowAt(int rowIndex) { + Vec row = rows.get(rowIndex); + if (row == null) { + row = new Vec(columnCount); + row.setAll(defaultValue); + row.setId(rowIndex); + rows.put(rowIndex, row); + } + return row; + } + + public void setAll(double value) { + defaultValue = value; + for (Vec row : rows.values()) { + row.setAll(value); + } + } + + public double get(int rowIndex, int columnIndex) { + Vec row = rowAt(rowIndex); + return row.get(columnIndex); + } + + public List columnVectors() { + Matrix A = this; + int n = A.getColumnCount(); + int rowCount = A.getRowCount(); + + List Acols = new ArrayList<>(); + + for (int c = 0; c < n; ++c) { + Vec Acol = new Vec(rowCount); + Acol.setAll(defaultValue); + Acol.setId(c); + + for (int r = 0; r < rowCount; ++r) { + Acol.set(r, A.get(r, c)); + } + Acols.add(Acol); + } + return Acols; + } + + public int getColumnCount() { + return this.columnCount; + } + + public Matrix multiply(Matrix rhs) { + if (this.getColumnCount() != rhs.getRowCount()) { + System.err.println("A.columnCount must be equal to B.rowCount in multiplication"); + return null; + } + + Vec row1; + Vec col2; + + Matrix result = new Matrix(getRowCount(), rhs.getColumnCount()); + result.setAll(defaultValue); + + List rhsColumns = rhs.columnVectors(); + + for (Map.Entry entry : rows.entrySet()) { + int r1 = entry.getKey(); + row1 = entry.getValue(); + for (int c2 = 0; c2 < rhsColumns.size(); ++c2) { + col2 
= rhsColumns.get(c2); + result.set(r1, c2, row1.multiply(col2)); + } + } + + return result; + } + +// @JSONField(serialize = false) + public boolean isSymmetric() { + if (getRowCount() != getColumnCount()) return false; + + for (Map.Entry rowEntry : rows.entrySet()) { + int row = rowEntry.getKey(); + Vec rowVec = rowEntry.getValue(); + + for (Integer col : rowVec.getData().keySet()) { + if (row == col) continue; + if (DoubleUtils.equals(rowVec.get(col), this.get(col, row))) { + return false; + } + } + } + + return true; + } + + public Vec multiply(Vec rhs) { + if (this.getColumnCount() != rhs.getDimension()) { + System.err + .println("columnCount must be equal to the size of the vector for multiplication"); + } + + Vec row1; + Vec result = new Vec(getRowCount()); + for (Map.Entry entry : rows.entrySet()) { + row1 = entry.getValue(); + result.set(entry.getKey(), row1.multiply(rhs)); + } + return result; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java index e43c28b..2bc8631 100644 --- a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java @@ -2,28 +2,24 @@ import java.util.List; - /** * Created by xschen on 10/11/2015 0011. */ public class MatrixUtils { - /** - * Convert a list of column vectors into a matrix - */ - public static Matrix matrixFromColumnVectors(List R) - { - int n = R.size(); - int m = R.get(0).getDimension(); + /** + * Convert a list of column vectors into a matrix + */ + public static Matrix matrixFromColumnVectors(List R) { + int n = R.size(); + int m = R.get(0).getDimension(); - Matrix T = new Matrix(m, n); - for (int c = 0; c < n; ++c) - { - Vec Rcol = R.get(c); - for (int r : Rcol.getData().keySet()) - { - T.set(r, c, Rcol.get(r)); - } - } - return T; - } + Matrix T = new Matrix(m, n); + for (int c = 0; c < n; ++c) { + Vec Rcol = R.get(c); + for (int r : Rcol.getData().keySet()) { + T.set(r, c, Rcol.get(r)); + } + } + return T; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java index b4895ea..f959a85 100644 --- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java +++ b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java @@ -1,56 +1,58 @@ package com.github.chen0040.rl.utils; +import java.util.Objects; + /** * Created by xschen on 10/11/2015 0011. */ public class TupleTwo { - private T1 item1; - private T2 item2; - - public TupleTwo(T1 item1, T2 item2){ - this.item1 = item1; - this.item2 = item2; - } - - public T1 getItem1() { - return item1; - } - - public void setItem1(T1 item1) { - this.item1 = item1; - } - - public T2 getItem2() { - return item2; - } - - public void setItem2(T2 item2) { - this.item2 = item2; - } - - public static TupleTwo create(U1 item1, U2 item2){ - return new TupleTwo(item1, item2); - } - - - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - TupleTwo tupleTwo = (TupleTwo) o; - - if (item1 != null ? !item1.equals(tupleTwo.item1) : tupleTwo.item1 != null) - return false; - return item2 != null ? item2.equals(tupleTwo.item2) : tupleTwo.item2 == null; - - } - - - @Override public int hashCode() { - int result = item1 != null ? item1.hashCode() : 0; - result = 31 * result + (item2 != null ? 
item2.hashCode() : 0); - return result; - } + private T1 item1; + private T2 item2; + + public TupleTwo(T1 item1, T2 item2) { + this.item1 = item1; + this.item2 = item2; + } + + public T1 getItem1() { + return item1; + } + + public void setItem1(T1 item1) { + this.item1 = item1; + } + + public T2 getItem2() { + return item2; + } + + public void setItem2(T2 item2) { + this.item2 = item2; + } + + public static TupleTwo create(U1 item1, U2 item2) { + return new TupleTwo<>(item1, item2); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + TupleTwo tupleTwo = (TupleTwo) o; + + if (!Objects.equals(item1, tupleTwo.item1)) + return false; + return Objects.equals(item2, tupleTwo.item2); + + } + + @Override + public int hashCode() { + int result = item1 != null ? item1.hashCode() : 0; + result = 31 * result + (item2 != null ? item2.hashCode() : 0); + return result; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index 4699d0e..1152464 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -9,341 +9,329 @@ import java.util.Map; import java.util.Set; - /** * Created by xschen on 9/27/2015 0027. */ @Getter @Setter public class Vec implements Serializable { - private Map data = new HashMap(); - private int dimension; - private double defaultValue; - private int id = -1; - - public Vec(){ - - } - - public Vec(double[] v){ - for(int i=0; i < v.length; ++i){ - set(i, v[i]); - } - } - - public Vec(int dimension){ - this.dimension = dimension; - defaultValue = 0; - } - - public Vec(int dimension, Map data){ - this.dimension = dimension; - defaultValue = 0; - - for(Map.Entry entry : data.entrySet()){ - set(entry.getKey(), entry.getValue()); - } - } - - public Vec makeCopy(){ - Vec clone = new Vec(dimension); - clone.copy(this); - return clone; - } - - public void copy(Vec rhs){ - defaultValue = rhs.defaultValue; - dimension = rhs.dimension; - id = rhs.id; - - data.clear(); - for(Map.Entry entry : rhs.data.entrySet()){ - data.put(entry.getKey(), entry.getValue()); - } - } - - public void set(int i, double value){ - if(value == defaultValue) return; - - data.put(i, value); - if(i >= dimension){ - dimension = i+1; - } - } - - - public double get(int i){ - return data.getOrDefault(i, defaultValue); - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Vec){ - Vec rhs2 = (Vec)rhs; - if(dimension != rhs2.dimension){ - return false; - } - - if(data.size() != rhs2.data.size()){ - return false; - } - - for(Integer index : data.keySet()){ - if(!rhs2.data.containsKey(index)) return false; - if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){ - return false; - } - } - - if(defaultValue != rhs2.defaultValue){ - for(int i=0; i < dimension; ++i){ - if(data.containsKey(i)){ - return false; - } - } - } - - return true; - } - - return false; - } - - public void setAll(double value){ - defaultValue = value; - for(Integer index : data.keySet()){ - data.put(index, defaultValue); - } - } - - public IndexValue indexWithMaxValue(Set indices){ - if(indices == null){ - return indexWithMaxValue(); - }else{ - IndexValue iv = new IndexValue(); - iv.setIndex(-1); - iv.setValue(Double.NEGATIVE_INFINITY); - for(Integer index : indices){ - double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY); - if(value > iv.getValue()){ - 
iv.setIndex(index); - iv.setValue(value); - } - } - return iv; - } - } - - public IndexValue indexWithMaxValue(){ - IndexValue iv = new IndexValue(); - iv.setIndex(-1); - iv.setValue(Double.NEGATIVE_INFINITY); - - - for(Map.Entry entry : data.entrySet()){ - if(entry.getKey() >= dimension) continue; - - double value = entry.getValue(); - if(value > iv.getValue()){ - iv.setValue(value); - iv.setIndex(entry.getKey()); - } - } - - if(!iv.isValid()){ - iv.setValue(defaultValue); - } else{ - if(iv.getValue() < defaultValue){ - for(int i=0; i < dimension; ++i){ - if(!data.containsKey(i)){ - iv.setValue(defaultValue); - iv.setIndex(i); - break; - } - } - } - } - - return iv; - } - - - - public Vec projectOrthogonal(Iterable vlist) { - Vec b = this; - for(Vec v : vlist) - { - b = b.minus(b.projectAlong(v)); - } - - return b; - } - - public Vec projectOrthogonal(List vlist, Map alpha) { - Vec b = this; - for(int i = 0; i < vlist.size(); ++i) - { - Vec v = vlist.get(i); - double norm_a = v.multiply(v); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); - } - double sigma = multiply(v) / norm_a; - Vec v_parallel = v.multiply(sigma); - - alpha.put(i, sigma); - - b = b.minus(v_parallel); - } - - return b; - } - - public Vec projectAlong(Vec rhs) - { - double norm_a = rhs.multiply(rhs); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); - } - double sigma = multiply(rhs) / norm_a; - return rhs.multiply(sigma); - } - - public Vec multiply(double rhs){ - Vec clone = (Vec)this.makeCopy(); - for(Integer i : data.keySet()){ - clone.data.put(i, rhs * data.get(i)); - } - return clone; - } - - public double multiply(Vec rhs) - { - double productSum = 0; - if(defaultValue == 0) { - for (Map.Entry entry : data.entrySet()) { - productSum += entry.getValue() * rhs.get(entry.getKey()); - } - } else { - for(int i=0; i < dimension; ++i){ - productSum += get(i) * rhs.get(i); - } - } - - return productSum; - } - - public Vec pow(double scalar) - { - Vec result = new Vec(dimension); - for (Map.Entry entry : data.entrySet()) - { - result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar)); - } - return result; - } - - public Vec add(Vec rhs) - { - Vec result = new Vec(dimension); - int index; - for (Map.Entry entry : data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() + rhs.data.get(index)); - } - for(Map.Entry entry : rhs.data.entrySet()){ - index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, entry.getValue() + data.get(index)); - } - - return result; - } - - public Vec minus(Vec rhs) - { - Vec result = new Vec(dimension); - int index; - for (Map.Entry entry : data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() - rhs.data.get(index)); - } - for(Map.Entry entry : rhs.data.entrySet()){ - index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, data.get(index) - entry.getValue()); - } - - return result; - } - - public double sum(){ - double sum = 0; - - for(Map.Entry entry : data.entrySet()){ - sum += entry.getValue(); - } - sum += defaultValue * (dimension - data.size()); - - return sum; - } - - public boolean isZero(){ - return DoubleUtils.isZero(sum()); - } - - public double norm(int level) - { - if (level == 1) - { - double sum = 0; - for (Double val : data.values()) - { - sum += Math.abs(val); - } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.abs(defaultValue) * (dimension - data.size()); - } - return sum; - } - else if 
(level == 2) - { - double sum = multiply(this); - if(!DoubleUtils.isZero(defaultValue)){ - sum += (dimension - data.size()) * (defaultValue * defaultValue); - } - return Math.sqrt(sum); - } - else - { - double sum = 0; - for (Double val : this.data.values()) - { - sum += Math.pow(Math.abs(val), level); - } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size()); - } - return Math.pow(sum, 1.0 / level); - } - } - - public Vec normalize() - { - double norm = norm(2); // L2 norm is the cartesian distance - if (DoubleUtils.isZero(norm)) - { - return new Vec(dimension); - } - Vec clone = new Vec(dimension); - clone.setAll(defaultValue / norm); - - for (Integer k : data.keySet()) - { - clone.data.put(k, data.get(k) / norm); - } - return clone; - } + private Map data = new HashMap<>(); + private int dimension; + private double defaultValue; + private int id = -1; + + public Vec() { + + } + + public Vec(double[] v) { + for (int i = 0; i < v.length; ++i) { + set(i, v[i]); + } + } + + public Vec(int dimension) { + this.dimension = dimension; + defaultValue = 0; + } + + public Vec(int dimension, Map data) { + this.dimension = dimension; + defaultValue = 0; + + for (Map.Entry entry : data.entrySet()) { + set(entry.getKey(), entry.getValue()); + } + } + + public Vec makeCopy() { + Vec clone = new Vec(dimension); + clone.copy(this); + return clone; + } + + public void copy(Vec rhs) { + defaultValue = rhs.defaultValue; + dimension = rhs.dimension; + id = rhs.id; + + data.clear(); + for (Map.Entry entry : rhs.data.entrySet()) { + data.put(entry.getKey(), entry.getValue()); + } + } + + public void set(int i, double value) { + if (value == defaultValue) return; + + data.put(i, value); + if (i >= dimension) { + dimension = i + 1; + } + } + + public double get(int i) { + return data.getOrDefault(i, defaultValue); + } + + @Override + public boolean equals(Object rhs) { + if (rhs != null && rhs instanceof Vec) { + Vec rhs2 = (Vec) rhs; + if (dimension != rhs2.dimension) { + return false; + } + + if (data.size() != rhs2.data.size()) { + return false; + } + + for (Integer index : data.keySet()) { + if (!rhs2.data.containsKey(index)) return false; + if (!DoubleUtils.equals(data.get(index), rhs2.data.get(index))) { + return false; + } + } + + if (defaultValue != rhs2.defaultValue) { + for (int i = 0; i < dimension; ++i) { + if (data.containsKey(i)) { + return false; + } + } + } + + return true; + } + + return false; + } + + public void setAll(double value) { + defaultValue = value; + for (Integer index : data.keySet()) { + data.put(index, defaultValue); + } + } + + public IndexValue indexWithMaxValue(Set indices) { + if (indices == null) { + return indexWithMaxValue(); + } else { + IndexValue iv = new IndexValue(); + iv.setIndex(-1); + iv.setValue(Double.NEGATIVE_INFINITY); + for (Integer index : indices) { + double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY); + if (value > iv.getValue()) { + iv.setIndex(index); + iv.setValue(value); + } + } + return iv; + } + } + + public IndexValue indexWithMaxValue() { + IndexValue iv = new IndexValue(); + iv.setIndex(-1); + iv.setValue(Double.NEGATIVE_INFINITY); + + for (Map.Entry entry : data.entrySet()) { + if (entry.getKey() >= dimension) continue; + + double value = entry.getValue(); + if (value > iv.getValue()) { + iv.setValue(value); + iv.setIndex(entry.getKey()); + } + } + + if (!iv.isValid()) { + iv.setValue(defaultValue); + } else { + if (iv.getValue() < defaultValue) { + for (int i = 0; i < 
dimension; ++i) { + if (!data.containsKey(i)) { + iv.setValue(defaultValue); + iv.setIndex(i); + break; + } + } + } + } + + return iv; + } + + public Vec projectOrthogonal(Iterable vlist) { + Vec b = this; + for (Vec v : vlist) { + b = b.minus(b.projectAlong(v)); + } + + return b; + } + + public Vec projectOrthogonal(List vlist, Map alpha) { + Vec b = this; + for (int i = 0; i < vlist.size(); ++i) { + Vec v = vlist.get(i); + double norm_a = v.multiply(v); + + if (DoubleUtils.isZero(norm_a)) { + return new Vec(dimension); + } + double sigma = multiply(v) / norm_a; + Vec v_parallel = v.multiply(sigma); + + alpha.put(i, sigma); + + b = b.minus(v_parallel); + } + + return b; + } + + public Vec projectAlong(Vec rhs) { + double norm_a = rhs.multiply(rhs); + + if (DoubleUtils.isZero(norm_a)) { + return new Vec(dimension); + } + double sigma = multiply(rhs) / norm_a; + return rhs.multiply(sigma); + } + + public Vec multiply(double rhs) { + Vec clone = (Vec) this.makeCopy(); + for (Integer i : data.keySet()) { + clone.data.put(i, rhs * data.get(i)); + } + return clone; + } + + public double multiply(Vec rhs) { + double productSum = 0; + if (defaultValue == 0) { + for (Map.Entry entry : data.entrySet()) { + productSum += entry.getValue() * rhs.get(entry.getKey()); + } + } else { + for (int i = 0; i < dimension; ++i) { + productSum += get(i) * rhs.get(i); + } + } + + return productSum; + } + + public Vec pow(double scalar) { + Vec result = new Vec(dimension); + for (Map.Entry entry : data.entrySet()) { + result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar)); + } + return result; + } + + public Vec add(Vec rhs) { + Vec result = new Vec(dimension); + int index; + for (Map.Entry entry : data.entrySet()) { + index = entry.getKey(); + result.data.put(index, entry.getValue() + rhs.data.get(index)); + } + for (Map.Entry entry : rhs.data.entrySet()) { + index = entry.getKey(); + if (result.data.containsKey(index)) continue; + result.data.put(index, entry.getValue() + data.get(index)); + } + + return result; + } + + public Vec minus(Vec rhs) { + Vec result = new Vec(dimension); + int index; + for (Map.Entry entry : data.entrySet()) { + index = entry.getKey(); + result.data.put(index, entry.getValue() - rhs.data.get(index)); + } + for (Map.Entry entry : rhs.data.entrySet()) { + index = entry.getKey(); + if (result.data.containsKey(index)) continue; + result.data.put(index, data.get(index) - entry.getValue()); + } + + return result; + } + + public double sum() { + double sum = 0; + + for (Map.Entry entry : data.entrySet()) { + sum += entry.getValue(); + } + sum += defaultValue * (dimension - data.size()); + + return sum; + } + + public boolean isZero() { + return DoubleUtils.isZero(sum()); + } + + public double norm(int level) { + if (level == 1) { + double sum = 0; + for (Double val : data.values()) { + sum += Math.abs(val); + } + if (!DoubleUtils.isZero(defaultValue)) { + sum += Math.abs(defaultValue) * (dimension - data.size()); + } + return sum; + } else if (level == 2) { + double sum = multiply(this); + if (!DoubleUtils.isZero(defaultValue)) { + sum += (dimension - data.size()) * (defaultValue * defaultValue); + } + return Math.sqrt(sum); + } else { + double sum = 0; + for (Double val : this.data.values()) { + sum += Math.pow(Math.abs(val), level); + } + if (!DoubleUtils.isZero(defaultValue)) { + sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size()); + } + return Math.pow(sum, 1.0 / level); + } + } + + public Vec normalize() { + double norm = norm(2); // L2 norm is the 
cartesian distance + if (DoubleUtils.isZero(norm)) { + return new Vec(dimension); + } + Vec clone = new Vec(dimension); + clone.setAll(defaultValue / norm); + + for (Integer k : data.keySet()) { + clone.data.put(k, data.get(k) / norm); + } + return clone; + } + + public void setId(int rowIndex) { + this.id = rowIndex; + } + + public Map getData() { + return this.data; + } + + public int getDimension() { + return this.dimension; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java index 2bbfbaa..da79781 100644 --- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java @@ -3,37 +3,30 @@ import java.util.ArrayList; import java.util.List; - /** * Created by xschen on 10/11/2015 0011. */ public class VectorUtils { - public static List removeZeroVectors(Iterable vlist) - { - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - if (!v.isZero()) - { - vstarlist.add(v); - } - } - - return vstarlist; - } - - public static TupleTwo, List> normalize(Iterable vlist) - { - List norms = new ArrayList(); - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - norms.add(v.norm(2)); - vstarlist.add(v.normalize()); - } - - return TupleTwo.create(vstarlist, norms); - } - + public static List removeZeroVectors(Iterable vlist) { + List vstarlist = new ArrayList<>(); + for (Vec v : vlist) { + if (!v.isZero()) { + vstarlist.add(v); + } + } + + return vstarlist; + } + + public static TupleTwo, List> normalize(Iterable vlist) { + List norms = new ArrayList<>(); + List vstarlist = new ArrayList<>(); + for (Vec v : vlist) { + norms.add(v.norm(2)); + vstarlist.add(v.normalize()); + } + + return TupleTwo.create(vstarlist, norms); + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java index 7b3d094..3c9d7cc 100644 --- a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java @@ -1,46 +1,45 @@ package com.github.chen0040.rl.learning.actorcritic; - import com.github.chen0040.rl.utils.Vec; + import org.testng.annotations.Test; import java.util.Random; import static org.testng.Assert.*; - /** * Created by xschen on 6/5/2017. 
*/ public class ActorCriticAgentUnitTest { - @Test - public void test_learn(){ - int stateCount = 100; - int actionCount = 10; + @Test + public void test_learn() { + int stateCount = 100; + int actionCount = 10; - ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount); - Vec stateValues = new Vec(stateCount); + ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount); + Vec stateValues = new Vec(stateCount); - Random random = new Random(); - agent.start(random.nextInt(stateCount)); - for(int time=0; time < 1000; ++time){ + Random random = new Random(); + agent.start(random.nextInt(stateCount)); + for (int time = 0; time < 1000; ++time) { - int actionId = agent.selectAction(); - System.out.println("Agent does action-"+actionId); + int actionId = agent.selectAction(); + System.out.println("Agent does action-" + actionId); - int newStateId = random.nextInt(actionCount); - double reward = random.nextDouble(); + int newStateId = random.nextInt(actionCount); + double reward = random.nextDouble(); - System.out.println("Now the new state is "+newStateId); - System.out.println("Agent receives Reward = "+reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Agent receives Reward = " + reward); - System.out.println("World state values changed ..."); - for(int stateId = 0; stateId < stateCount; ++stateId){ - stateValues.set(stateId, random.nextDouble()); - } + System.out.println("World state values changed ..."); + for (int stateId = 0; stateId < stateCount; ++stateId) { + stateValues.set(stateId, random.nextDouble()); + } - agent.update(actionId, newStateId, reward, stateValues); - } - } + agent.update(actionId, newStateId, reward, stateValues); + } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java index d6bafac..3c0bdf8 100644 --- a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.learning.actorcritic; import com.github.chen0040.rl.utils.Vec; + import org.testng.annotations.Test; import java.util.Random; @@ -12,39 +13,39 @@ */ public class ActorCriticLearnerUnitTest { - @Test - public void test_learn(){ - int stateCount = 100; - int actionCount = 10; + @Test + public void test_learn() { + int stateCount = 100; + int actionCount = 10; - ActorCriticLearner learner = new ActorCriticLearner(stateCount, actionCount); - final Vec stateValues = new Vec(stateCount); + ActorCriticLearner learner = new ActorCriticLearner(stateCount, actionCount); + final Vec stateValues = new Vec(stateCount); - Random random = new Random(); - int currentStateId = random.nextInt(stateCount); - for(int time=0; time < 1000; ++time){ + Random random = new Random(); + int currentStateId = random.nextInt(stateCount); + for (int time = 0; time < 1000; ++time) { - int actionId = learner.selectAction(currentStateId); - System.out.println("Agent does action-"+actionId); + int actionId = learner.selectAction(currentStateId); + System.out.println("Agent does action-" + actionId); - int newStateId = random.nextInt(actionCount); - double reward = random.nextDouble(); + int newStateId = random.nextInt(actionCount); + double reward = random.nextDouble(); - System.out.println("Now the new state is "+newStateId); - System.out.println("Agent receives 
Reward = "+reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Agent receives Reward = " + reward); - System.out.println("World state values changed ..."); - for(int stateId = 0; stateId < stateCount; ++stateId){ - stateValues.set(stateId, random.nextDouble()); - } + System.out.println("World state values changed ..."); + for (int stateId = 0; stateId < stateCount; ++stateId) { + stateValues.set(stateId, random.nextDouble()); + } - learner.update(currentStateId, actionId, newStateId, reward, stateValues::get); - } + learner.update(currentStateId, actionId, newStateId, reward, stateValues::get); + } - ActorCriticLearner learner2 = ActorCriticLearner.fromJson(learner.toJson()); + ActorCriticLearner learner2 = ActorCriticLearner.fromJson(learner.toJson()); - assertThat(learner2.getP()).isEqualTo(learner.getP()); - assertThat(learner2.getActionSelection()).isEqualTo(learner.getActionSelection()); - assertThat(learner2).isEqualTo(learner); - } + assertThat(learner2.getP()).isEqualTo(learner.getP()); + assertThat(learner2.getActionSelection()).isEqualTo(learner.getActionSelection()); + assertThat(learner2).isEqualTo(learner); + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java index ee30dc4..6b5b763 100644 --- a/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java @@ -1,34 +1,33 @@ package com.github.chen0040.rl.learning.models; -import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.JSON; import com.github.chen0040.rl.models.QModel; +import com.google.gson.Gson; + import org.testng.annotations.Test; import static org.assertj.core.api.Java6Assertions.assertThat; public class QModelUnitTest { - @Test - public void testJsonSerialization() { - QModel model = new QModel(100, 10); - - model.setQ(3, 4, 0.3); - model.setQ(92, 2, 0.2); - - model.setAlpha(0.4); - model.setGamma(0.3); - - String json = JSON.toJSONString(model); - QModel model2 = JSON.parseObject(json, QModel.class); + @Test + public void testJsonSerialization() { + QModel model = new QModel(100, 10); - assertThat(model).isEqualTo(model2); - assertThat(model.getQ()).isEqualTo(model2.getQ()); - assertThat(model.getAlphaMatrix()).isEqualTo(model2.getAlphaMatrix()); - assertThat(model.getStateCount()).isEqualTo(model2.getStateCount()); - assertThat(model.getActionCount()).isEqualTo(model2.getActionCount()); - assertThat(model.getGamma()).isEqualTo(model2.getGamma()); + model.setQ(3, 4, 0.3); + model.setQ(92, 2, 0.2); + model.setAlpha(0.4); + model.setGamma(0.3); + String json = new Gson().toJson(model); //JSON.toJSONString(model); + QModel model2 = new Gson().fromJson(json, QModel.class); //JSON.parseObject(json, QModel.class); + assertThat(model).isEqualTo(model2); + assertThat(model.getQ()).isEqualTo(model2.getQ()); + assertThat(model.getAlphaMatrix()).isEqualTo(model2.getAlphaMatrix()); + assertThat(model.getStateCount()).isEqualTo(model2.getStateCount()); + assertThat(model.getActionCount()).isEqualTo(model2.getActionCount()); + assertThat(model.getGamma()).isEqualTo(model2.getGamma()); - } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java index 627d7d0..c7bed39 100644 --- a/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java +++ 
b/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java @@ -1,41 +1,41 @@ package com.github.chen0040.rl.learning.qlearn; - import com.github.chen0040.rl.actionselection.SoftMaxActionSelectionStrategy; + import org.testng.annotations.Test; import java.util.Random; import static org.testng.Assert.*; - /** * Created by xschen on 6/5/2017. */ public class QAgentUnitTest { - @Test - public void test_q_learn(){ - int stateCount = 100; - int actionCount = 10; - QAgent agent = new QAgent(stateCount, actionCount); + @Test + public void test_q_learn() { + int stateCount = 100; + int actionCount = 10; + QAgent agent = new QAgent(stateCount, actionCount); - agent.getLearner().setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName()); + agent.getLearner() + .setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName()); - Random random = new Random(); - agent.start(random.nextInt(stateCount)); - for(int time=0; time < 1000; ++time){ + Random random = new Random(); + agent.start(random.nextInt(stateCount)); + for (int time = 0; time < 1000; ++time) { - int actionId = agent.selectAction().getIndex(); - System.out.println("Agent does action-"+actionId); + int actionId = agent.selectAction().getIndex(); + System.out.println("Agent does action-" + actionId); - int newStateId = random.nextInt(actionCount); - double reward = random.nextDouble(); + int newStateId = random.nextInt(actionCount); + double reward = random.nextDouble(); - System.out.println("Now the new state is "+newStateId); - System.out.println("Agent receives Reward = "+reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Agent receives Reward = " + reward); - agent.update(actionId, newStateId, reward); - } - } + agent.update(actionId, newStateId, reward); + } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java index 6685cac..a0023fb 100644 --- a/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java @@ -1,8 +1,8 @@ package com.github.chen0040.rl.learning.qlearn; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; import org.testng.annotations.Test; import java.util.Random; @@ -10,59 +10,56 @@ import static org.assertj.core.api.Java6Assertions.assertThat; import static org.testng.Assert.*; - /** * Created by xschen on 6/5/2017. 
*/ public class QLearnerUnitTest { - private static final int stateCount = 100; - private static final int actionCount = 10; - - @Test - public void testJsonSerialization() { - - QLearner learner = new QLearner(stateCount, actionCount); + private static final int stateCount = 100; + private static final int actionCount = 10; - run(learner); + @Test + public void testJsonSerialization() { - String json = learner.toJson(); + QLearner learner = new QLearner(stateCount, actionCount); + run(learner); - QLearner learner2 = QLearner.fromJson(json); + String json = learner.toJson(); - assertThat(learner.getModel()).isEqualTo(learner2.getModel()); + QLearner learner2 = QLearner.fromJson(json); - assertThat(learner.getActionSelection()).isEqualTo(learner2.getActionSelection()); + assertThat(learner.getModel()).isEqualTo(learner2.getModel()); - } + assertThat(learner.getActionSelection()).isEqualTo(learner2.getActionSelection()); - @Test - public void test_q_learn(){ + } + @Test + public void test_q_learn() { - QLearner learner = new QLearner(stateCount, actionCount); + QLearner learner = new QLearner(stateCount, actionCount); - run(learner); + run(learner); - } + } - private void run(QLearner learner) { - Random random = new Random(); - int currentStateId = random.nextInt(stateCount); - for(int time=0; time < 1000; ++time){ + private void run(QLearner learner) { + Random random = new Random(); + int currentStateId = random.nextInt(stateCount); + for (int time = 0; time < 1000; ++time) { - int actionId = learner.selectAction(currentStateId).getIndex(); - System.out.println("Controller does action-"+actionId); + int actionId = learner.selectAction(currentStateId).getIndex(); + System.out.println("Controller does action-" + actionId); - int newStateId = random.nextInt(actionCount); - double reward = random.nextDouble(); + int newStateId = random.nextInt(actionCount); + double reward = random.nextDouble(); - System.out.println("Now the new state is "+newStateId); - System.out.println("Controller receives Reward = "+reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Controller receives Reward = " + reward); - learner.update(currentStateId, actionId, newStateId, reward); - currentStateId = newStateId; - } - } + learner.update(currentStateId, actionId, newStateId, reward); + currentStateId = newStateId; + } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java index 6110298..c3f668b 100644 --- a/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.learning.rlearn; - import com.github.chen0040.rl.utils.IndexValue; + import org.testng.annotations.Test; import java.util.Random; @@ -9,41 +9,39 @@ import static org.assertj.core.api.Java6Assertions.assertThat; import static org.testng.Assert.*; - /** * Created by xschen on 6/5/2017. 
*/ public class RAgentUnitTest { - @Test - public void test_r_learn(){ - - int stateCount = 100; - int actionCount = 10; - RAgent agent = new RAgent(stateCount, actionCount); + @Test + public void test_r_learn() { - Random random = new Random(); - agent.start(random.nextInt(stateCount)); - for(int time=0; time < 1000; ++time){ + int stateCount = 100; + int actionCount = 10; + RAgent agent = new RAgent(stateCount, actionCount); - IndexValue actionValue = agent.selectAction(); - int actionId = actionValue.getIndex(); - System.out.println("Agent does action-"+actionId); + Random random = new Random(); + agent.start(random.nextInt(stateCount)); + for (int time = 0; time < 1000; ++time) { - int newStateId = random.nextInt(actionCount); - double reward = random.nextDouble(); + IndexValue actionValue = agent.selectAction(); + int actionId = actionValue.getIndex(); + System.out.println("Agent does action-" + actionId); - System.out.println("Now the new state is "+newStateId); - System.out.println("Agent receives Reward = "+reward); + int newStateId = random.nextInt(actionCount); + double reward = random.nextDouble(); - agent.update(newStateId, reward); - } + System.out.println("Now the new state is " + newStateId); + System.out.println("Agent receives Reward = " + reward); - RLearner learner = agent.getLearner(); - RLearner learner2 = RLearner.fromJson(learner.toJson()); + agent.update(newStateId, reward); + } - assertThat(learner).isEqualTo(learner2); + RLearner learner = agent.getLearner(); + RLearner learner2 = RLearner.fromJson(learner.toJson()); + assertThat(learner).isEqualTo(learner2); - } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java index 0959b33..fa82642 100644 --- a/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java @@ -1,6 +1,5 @@ package com.github.chen0040.rl.learning.sarsa; - import org.testng.annotations.Test; import java.util.Random; @@ -8,39 +7,38 @@ import static org.assertj.core.api.Java6Assertions.assertThat; import static org.testng.Assert.*; - /** * Created by xschen on 6/5/2017. 
*/ public class SarsaAgentUnitTest { - @Test - public void test_sarsa(){ - int stateCount = 100; - int actionCount = 10; - SarsaAgent agent = new SarsaAgent(stateCount, actionCount); + @Test + public void test_sarsa() { + int stateCount = 100; + int actionCount = 10; + SarsaAgent agent = new SarsaAgent(stateCount, actionCount); - double reward = 0; //immediate reward by transiting from prevState to currentState - Random random = new Random(); - agent.start(random.nextInt(stateCount)); - int actionTaken = agent.selectAction().getIndex(); - for(int time=0; time < 1000; ++time){ + double reward = 0; //immediate reward by transiting from prevState to currentState + Random random = new Random(); + agent.start(random.nextInt(stateCount)); + int actionTaken = agent.selectAction().getIndex(); + for (int time = 0; time < 1000; ++time) { - System.out.println("Agent does action-"+actionTaken); + System.out.println("Agent does action-" + actionTaken); - int newStateId = random.nextInt(actionCount); - reward = random.nextDouble(); + int newStateId = random.nextInt(actionCount); + reward = random.nextDouble(); - System.out.println("Now the new state is "+newStateId); - System.out.println("Agent receives Reward = "+reward); + System.out.println("Now the new state is " + newStateId); + System.out.println("Agent receives Reward = " + reward); - agent.update(actionTaken, newStateId, reward); - } + agent.update(actionTaken, newStateId, reward); + } - SarsaLearner learner = agent.getLearner(); + SarsaLearner learner = agent.getLearner(); - SarsaLearner learner2 = SarsaLearner.fromJson(learner.toJson()); + SarsaLearner learner2 = SarsaLearner.fromJson(learner.toJson()); - assertThat(learner2).isEqualTo(learner); - } + assertThat(learner2).isEqualTo(learner); + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java index 2e9cf52..a613388 100644 --- a/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java @@ -1,8 +1,10 @@ package com.github.chen0040.rl.learning.utils; -import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.serializer.SerializerFeature; +//import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.utils.Matrix; +import com.google.gson.Gson; + import org.testng.annotations.Test; import java.util.Random; @@ -11,55 +13,55 @@ public class MatrixUnitTest { - private static final Random random = new Random(42); + private static final Random random = new Random(42); - @Test - public void testJsonSerialization() { - Matrix matrix = new Matrix(10, 10); - matrix.set(0, 0, 10); - matrix.set(4, 2, 2); - matrix.set(3, 3, 2); + @Test + public void testJsonSerialization() { + Matrix matrix = new Matrix(10, 10); + matrix.set(0, 0, 10); + matrix.set(4, 2, 2); + matrix.set(3, 3, 2); - assertThat(matrix.get(0, 0)).isEqualTo(10); - assertThat(matrix.get(4, 2)).isEqualTo(2); - assertThat(matrix.get(3, 3)).isEqualTo(2); - assertThat(matrix.get(4, 4)).isEqualTo(0); + assertThat(matrix.get(0, 0)).isEqualTo(10); + assertThat(matrix.get(4, 2)).isEqualTo(2); + assertThat(matrix.get(3, 3)).isEqualTo(2); + assertThat(matrix.get(4, 4)).isEqualTo(0); - assertThat(matrix.getRowCount()).isEqualTo(10); - assertThat(matrix.getColumnCount()).isEqualTo(10); + assertThat(matrix.getRowCount()).isEqualTo(10); + assertThat(matrix.getColumnCount()).isEqualTo(10); - 
String json = JSON.toJSONString(matrix, SerializerFeature.PrettyFormat); + String json = new Gson().toJson(matrix); //JSON.toJSONString(matrix, SerializerFeature.PrettyFormat); - System.out.println(json); - Matrix matrix2 = JSON.parseObject(json, Matrix.class); - assertThat(matrix).isEqualTo(matrix2); + System.out.println(json); + Matrix matrix2 = new Gson().fromJson(json, Matrix.class); //JSON.parseObject(json, Matrix.class); + assertThat(matrix).isEqualTo(matrix2); - for(int i=0; i < matrix.getRowCount(); ++i){ - for(int j=0; j < matrix.getColumnCount(); ++j) { - assertThat(matrix.get(i, j)).isEqualTo(matrix2.get(i, j)); - } - } - } + for (int i = 0; i < matrix.getRowCount(); ++i) { + for (int j = 0; j < matrix.getColumnCount(); ++j) { + assertThat(matrix.get(i, j)).isEqualTo(matrix2.get(i, j)); + } + } + } - @Test - public void testJsonSerialization_Random() { - Matrix matrix = new Matrix(10, 10); - for(int i=0; i < matrix.getRowCount(); ++i){ - for(int j=0; j < matrix.getColumnCount(); ++j){ - matrix.set(i, j, random.nextDouble()); - } - } - Matrix matrix2 = matrix.makeCopy(); - assertThat(matrix).isEqualTo(matrix2); + @Test + public void testJsonSerialization_Random() { + Matrix matrix = new Matrix(10, 10); + for (int i = 0; i < matrix.getRowCount(); ++i) { + for (int j = 0; j < matrix.getColumnCount(); ++j) { + matrix.set(i, j, random.nextDouble()); + } + } + Matrix matrix2 = matrix.makeCopy(); + assertThat(matrix).isEqualTo(matrix2); - String json = JSON.toJSONString(matrix); - Matrix matrix3 = JSON.parseObject(json, Matrix.class); - assertThat(matrix2).isEqualTo(matrix3); + String json = new Gson().toJson(matrix); //JSON.toJSONString(matrix); + Matrix matrix3 = new Gson().fromJson(json, Matrix.class); //JSON.parseObject(json, Matrix.class); + assertThat(matrix2).isEqualTo(matrix3); - for(int i=0; i < matrix.getRowCount(); ++i){ - for(int j=0; j < matrix.getColumnCount(); ++j){ - assertThat(matrix2.get(i, j)).isEqualTo(matrix3.get(i, j)); - } - } - } + for (int i = 0; i < matrix.getRowCount(); ++i) { + for (int j = 0; j < matrix.getColumnCount(); ++j) { + assertThat(matrix2.get(i, j)).isEqualTo(matrix3.get(i, j)); + } + } + } } diff --git a/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java index daa4396..40c87bd 100644 --- a/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java +++ b/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java @@ -1,20 +1,22 @@ package com.github.chen0040.rl.learning.utils; -import com.alibaba.fastjson.JSON; +//import com.alibaba.fastjson.JSON; import com.github.chen0040.rl.utils.Vec; +import com.google.gson.Gson; + import org.testng.annotations.Test; import static org.assertj.core.api.Java6Assertions.assertThat; public class VecUnitTest { - @Test - public void testJsonSerialization() { - Vec vec = new Vec(100); - vec.set(9, 100); - vec.set(11, 2); - vec.set(0, 1); - String json = JSON.toJSONString(vec); - Vec vec2 = JSON.parseObject(json, Vec.class); - assertThat(vec).isEqualTo(vec2); - } + @Test + public void testJsonSerialization() { + Vec vec = new Vec(100); + vec.set(9, 100); + vec.set(11, 2); + vec.set(0, 1); + String json = new Gson().toJson(vec); //JSON.toJSONString(vec); + Vec vec2 = new Gson().fromJson(json, Vec.class); //JSON.parseObject(json, Vec.class); + assertThat(vec).isEqualTo(vec2); + } } diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties index ef69b72..98799c4 100644 --- 
a/src/test/resources/log4j.properties +++ b/src/test/resources/log4j.properties @@ -1,9 +1,7 @@ # Set root logger level to DEBUG and its only appender to A1. log4j.rootLogger=DEBUG, A1 - # A1 is set to be a ConsoleAppender. log4j.appender.A1=org.apache.log4j.ConsoleAppender - # A1 uses PatternLayout. log4j.appender.A1.layout=org.apache.log4j.PatternLayout log4j.appender.A1.layout.ConversionPattern=%-5p %c %x - %m%n From 18b393314e92ed510f1f7fa99b8028f1901424d9 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sat, 18 May 2019 18:31:03 +0200 Subject: [PATCH 2/2] Included Demo TicTacToe as Application (result in logcat) --- .mvn/wrapper/maven-wrapper.properties | 2 +- build.gradle | 13 ++-- demo-tic-tac-toe/build.gradle | 61 +++++++++++++++++++ demo-tic-tac-toe/src/main/AndroidManifest.xml | 2 + settings.gradle | 2 + .../chen0040/rl/learning/rlearn/RLearner.java | 2 +- .../com/github/chen0040/rl/models/QModel.java | 8 +-- .../chen0040/rl/models/UtilityModel.java | 8 +-- .../github/chen0040/rl/utils/IndexValue.java | 8 +-- .../com/github/chen0040/rl/utils/Matrix.java | 8 +-- .../com/github/chen0040/rl/utils/Vec.java | 8 +-- 11 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 demo-tic-tac-toe/build.gradle create mode 100644 demo-tic-tac-toe/src/main/AndroidManifest.xml diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties index 56bb016..b5943ce 100644 --- a/.mvn/wrapper/maven-wrapper.properties +++ b/.mvn/wrapper/maven-wrapper.properties @@ -1 +1 @@ -distributionUrl=https://repo1.maven.org/maven2/org/apache/maven/apache-maven/3.5.0/apache-maven-3.5.0-bin.zip \ No newline at end of file +#distributionUrl=https://repo1.maven.org/maven2/org/apache/maven/apache-maven/3.5.0/apache-maven-3.5.0-bin.zip \ No newline at end of file diff --git a/build.gradle b/build.gradle index 0051418..091f560 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,11 @@ buildscript { repositories { google() jcenter() + maven { + url "https://jitpack.io" + } } + dependencies { classpath 'com.android.tools.build:gradle:3.5.0-beta01' @@ -17,6 +21,9 @@ allprojects { repositories { google() jcenter() + maven { + url "https://jitpack.io" + } } } @@ -64,11 +71,7 @@ android { } dependencies { - testImplementation 'junit:junit:4.12' - implementation 'androidx.appcompat:appcompat:1.0.2' - - compileOnly 'org.projectlombok:lombok:1.18.8' - annotationProcessor 'org.projectlombok:lombok:1.18.8' + testImplementation ('junit:junit:4.12') { exclude module: 'hamcrest-core' } implementation 'com.google.code.gson:gson:2.8.5' implementation 'org.testng:testng:6.9.6' diff --git a/demo-tic-tac-toe/build.gradle b/demo-tic-tac-toe/build.gradle new file mode 100644 index 0000000..80f302e --- /dev/null +++ b/demo-tic-tac-toe/build.gradle @@ -0,0 +1,61 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + + defaultConfig { + minSdkVersion 24 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + + lintOptions { + abortOnError false + } + + buildTypes { + debug { + testCoverageEnabled false + } + } + + sourceSets { + main { + java { + // Merge source sets instead of adding rushcore as submodule so that the test coverage report works + srcDirs = ['src/main/java'] + } + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { +// 
implementation 'androidx.appcompat:appcompat:1.0.2' + testImplementation ('junit:junit:4.12') { exclude module: 'hamcrest-core' } + modules { + module("org.hamcrest:hamcrest-core") { + replacedBy("junit:junit", "Vous ") + } + } + + implementation ('org.slf4j:slf4j-simple:1.8.0-beta4') { exclude module: 'junit' } + + implementation 'org.testng:testng:6.9.6' + implementation 'org.assertj:assertj-core:3.12.2' + testImplementation project(path: ':') + +// implementation project(path: ':java-reinforcement-learning') + implementation project(path: ':') +} diff --git a/demo-tic-tac-toe/src/main/AndroidManifest.xml b/demo-tic-tac-toe/src/main/AndroidManifest.xml new file mode 100644 index 0000000..e77dbbd --- /dev/null +++ b/demo-tic-tac-toe/src/main/AndroidManifest.xml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 0720936..2a773c9 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1 +1,3 @@ //include ':java-reinforcement-learning' +include ':demo-tic-tac-toe' +project(':demo-tic-tac-toe').projectDir = new File('demo-tic-tac-toe') \ No newline at end of file diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java index 40dea5a..d80edbe 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java @@ -10,7 +10,7 @@ import com.github.chen0040.rl.utils.IndexValue; import com.google.gson.Gson; -import lombok.Getter; +//import lombok.Getter; import java.io.Serializable; import java.util.Set; diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index 22cbde9..e4a61ce 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -4,8 +4,8 @@ import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; +//import lombok.Getter; +//import lombok.Setter; import java.util.*; @@ -13,8 +13,8 @@ * @author xschen 9/27/2015 0027. 
Q is known as the quality of state-action combination, note that * it is different from utility of a state */ -@Getter -@Setter +//@Getter +//@Setter public class QModel { /** * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java index 8188236..e8e2eb4 100644 --- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java +++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java @@ -2,8 +2,8 @@ import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; +//import lombok.Getter; +//import lombok.Setter; import java.io.Serializable; @@ -18,8 +18,8 @@ * T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ * is applied at state $s$ */ -@Getter -@Setter +//@Getter +//@Setter public class UtilityModel implements Serializable { private Vec U; private int stateCount; diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java index 621e7a3..264dcc3 100644 --- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java +++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java @@ -1,13 +1,13 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; +//import lombok.Getter; +//import lombok.Setter; /** * Created by xschen on 6/5/2017. */ -@Getter -@Setter +//@Getter +//@Setter public class IndexValue { private int index; private double value; diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java index 2421a8b..30793b7 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java +++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java @@ -2,8 +2,8 @@ //import com.alibaba.fastjson.annotation.JSONField; -import lombok.Getter; -import lombok.Setter; +//import lombok.Getter; +//import lombok.Setter; import java.io.Serializable; import java.util.ArrayList; @@ -14,8 +14,8 @@ /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter +//@Getter +//@Setter public class Matrix implements Serializable { private Map rows = new HashMap<>(); private int rowCount; diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index 1152464..763005a 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; +//import lombok.Getter; +//import lombok.Setter; import java.io.Serializable; import java.util.HashMap; @@ -12,8 +12,8 @@ /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter +//@Getter +//@Setter public class Vec implements Serializable { private Map data = new HashMap<>(); private int dimension;