feat: Add support for ContextualBandit in the VW module #896
jackgerrits commented Jul 16, 2020
- Adds VowpalWabbitContextualBandit, VowpalWabbitContextualBanditModel, ColumnVectorSequencer classes
- Update com.github.vowpalwabbit dependency version for CB support
- Add tests in Scala and Python for the new functionality
- Other featurizer improvements from @eisber
src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitClassifier.scala
```scala
def add_example(p_log: Double, reward: Double, p_pred: Double, count: Int = 1): Unit = {
  total_events += count
  if (p_pred > 0) {
    val p_over_p = p_pred / p_log
```
@marco-rossi29 can you review?
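The `add_example` method in the hunk above accumulates an importance weight `p_pred / p_log` for each logged event, which is the core of inverse propensity scoring (IPS). The following is a minimal Python sketch of such an accumulator, assuming the standard IPS estimate (sum of weighted rewards divided by the event count) and the self-normalized variant SNIPS (divided by the sum of weights instead). The class name, field names, and the `ips`/`snips` methods are illustrative assumptions, not the PR's Scala code.

```python
class IpsSnipsEstimator:
    """Hypothetical off-policy estimator mirroring the add_example signature in the PR."""

    def __init__(self):
        self.total_events = 0       # N: number of logged events seen (weighted by count)
        self.weighted_reward = 0.0  # running sum of w * reward
        self.weight_sum = 0.0       # running sum of w, the SNIPS normalizer

    def add_example(self, p_log, reward, p_pred, count=1):
        self.total_events += count
        # Events the evaluated policy would never choose (p_pred == 0) add zero weight.
        if p_pred > 0:
            p_over_p = p_pred / p_log  # importance weight w = p_pred / p_log
            self.weight_sum += p_over_p * count
            self.weighted_reward += reward * p_over_p * count

    def ips(self):
        # IPS: (1 / N) * sum(w * reward)
        return self.weighted_reward / self.total_events if self.total_events else 0.0

    def snips(self):
        # SNIPS: sum(w * reward) / sum(w)
        return self.weighted_reward / self.weight_sum if self.weight_sum else 0.0
```

Skipping the update when `p_pred` is zero still counts the event in `total_events`, so IPS is pulled toward zero by actions the new policy never takes, while SNIPS only normalizes over events that carry weight.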
src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitContextualBandit.scala
src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitUtil.scala
src/test/scala/com/microsoft/ml/spark/vw/VerifyVowpalWabbitFeaturizer.scala
/azp run
Azure Pipelines successfully started running 1 pipeline(s).
@eisber, considering this only really supports `cb_explore_adf`, should I update the class naming to reflect that? Does it seem like an issue that it is scoped to that?
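For context on the naming question: with `--cb_explore_adf`, VW produces a probability mass function (PMF) over the candidate actions rather than a single chosen action. The sketch below illustrates that kind of output using epsilon-greedy, one of the exploration strategies VW offers. The function names, the per-action cost list as input, and the sampling helper are hypothetical; this is not VW's implementation.

```python
import random


def epsilon_greedy_pmf(costs, epsilon=0.05):
    """Turn per-action predicted costs into an epsilon-greedy PMF (hypothetical helper).

    Lower cost is better, following VW's cost-based contextual bandit convention.
    """
    n = len(costs)
    if n == 0:
        raise ValueError("need at least one action")
    if not 0.0 <= epsilon <= 1.0:
        raise ValueError("epsilon must be in [0, 1]")
    best = min(range(n), key=lambda i: costs[i])
    pmf = [epsilon / n] * n        # spread epsilon uniformly over all actions
    pmf[best] += 1.0 - epsilon     # the greedy action gets the remaining mass
    return pmf


def sample_action(pmf, rng):
    """Draw an action index from the PMF; the logged probability is pmf[index]."""
    index = rng.choices(range(len(pmf)), weights=pmf, k=1)[0]
    return index, pmf[index]
```

The probability returned with the sampled action is what gets logged as `p_log`, which is why an exploring policy's output is needed for later off-policy evaluation.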
Great work! A few minor nits on the tests and I think it's ready to roll!
```python
class VowpalWabbitSpec(unittest.TestCase):
    def get_data(self):
        # create sample data
        schema = StructType([
```
You can just pass the column names if you are okay with the standard schema inference.
It seems that, by default, schema inference makes whole numbers `Long`s, whereas the model expects `Integer`s. I could extend the internal conversions to convert `Long` to `Int`, but then a number that is too large would cause a conversion exception rather than a schema validation failure. It seems better to constrain the schema.
no worries!
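The trade-off in this thread can be sketched as a range-checked Long-to-Int narrowing. `long_to_int` and `narrow_column` are hypothetical helpers, not part of the VW module: with this approach an out-of-range value only fails when the offending row is converted, whereas constraining the schema to `IntegerType` surfaces the mismatch as a schema validation failure, as the author prefers.

```python
# Bounds of a 32-bit signed integer (Spark's IntegerType); LongType is 64-bit.
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1


def long_to_int(value):
    """Narrow a 64-bit value to the 32-bit range, failing loudly on overflow."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"{value} does not fit in a 32-bit integer")
    return value


def narrow_column(values):
    """Convert a column of Long values row by row; fails at the first bad row."""
    return [long_to_int(v) for v in values]
```

Python's `int` is unbounded, so the explicit bounds stand in for the fixed width of the JVM types.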
```python
            StructField("probability", DoubleType())
        ])

        data = pyspark.sql.SparkSession.builder.getOrCreate().createDataFrame([
```
`spark` should already be available to you because of the imports.
Fixed
```python
            StructField("probability", DoubleType())
        ])

        data = pyspark.sql.SparkSession.builder.getOrCreate().createDataFrame([
```
likewise here
```python

    def get_data_two_shared(self):
        # create sample data
        schema = StructType([
```
same here
```scala
  }
}

class VerifyVowpalWabbitContextualBanditFuzzing extends EstimatorFuzzing[VowpalWabbitContextualBandit] {
```
This doesn't need to be a separate class from your contextual bandit tests.
Fixed