feat: Add support for ContextualBandit in the VW module #896

jackgerrits
Member

  • Add VowpalWabbitContextualBandit, VowpalWabbitContextualBanditModel, and ColumnVectorSequencer classes
  • Update com.github.vowpalwabbit dependency version for CB support
  • Add tests in Scala and Python for the new functionality
  • Other featurizer improvements from @eisber
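The description does not say exactly what ColumnVectorSequencer does. Going by its name, and assuming it follows the usual Spark transformer pattern of several input columns plus one output column, a plausible reading is that it packs several per-row vector columns into a single ordered sequence, for example one feature vector per candidate action in a contextual bandit example. The plain-Python sketch below illustrates only that idea, on rows represented as dictionaries; `sequence_columns` and the column names are hypothetical and are not part of the MMLSpark API.

```python
from typing import Dict, List, Sequence

Row = Dict[str, object]


def sequence_columns(rows: Sequence[Row], input_cols: Sequence[str], output_col: str) -> List[Row]:
    """Hypothetical stand-in for a column sequencer: for each row, gather the
    values of `input_cols` (in the given order) into one list in `output_col`."""
    out: List[Row] = []
    for row in rows:
        new_row = dict(row)  # copy so the caller's rows are not mutated
        new_row[output_col] = [row[c] for c in input_cols]
        out.append(new_row)
    return out


# Two per-action feature vectors become one sequence column.
rows = [{"action1": [1.0, 0.0], "action2": [0.0, 1.0], "cost": -1.0}]
seq = sequence_columns(rows, ["action1", "action2"], "features")
print(seq[0]["features"])  # [[1.0, 0.0], [0.0, 1.0]]
```

Keeping the input order matters in this sketch: position i in the output sequence is the only link back to action i.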

jackgerrits changed the title from "Feat: Add support for ContextualBandit in the VW module" to "feat: Add support for ContextualBandit in the VW module" on Jul 16, 2020
```scala
def add_example(p_log: Double, reward: Double, p_pred: Double, count: Int = 1): Unit = {
  total_events += count
  if (p_pred > 0) {
    val p_over_p = p_pred / p_log
    // ... (rest of the method not shown in this review excerpt)
  }
}
```
Collaborator

@marco-rossi29 can you review?
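The `add_example` excerpt above computes `p_pred / p_log`, the ratio of the probability the evaluated policy gives the logged action to the probability the logging policy used. A common use of that importance weight is off-policy evaluation with inverse propensity scoring (IPS) and its self-normalized variant (SNIPS). The sketch below assumes that is what the surrounding class accumulates; the class name, field names, and the `ips`/`snips` methods are hypothetical, not the code under review. IPS divides the weighted reward by the number of logged events and is unbiased when the logged probabilities are correct and every action has nonzero logging probability; SNIPS divides by the sum of weights instead, trading a small bias for lower variance.

```python
class IpsSnipsEstimator:
    """Hypothetical off-policy estimator using the same
    add_example(p_log, reward, p_pred, count) shape as the Scala excerpt."""

    def __init__(self) -> None:
        self.total_events = 0        # all logged events seen
        self.weighted_reward = 0.0   # sum of reward * (p_pred / p_log) * count
        self.total_weight = 0.0      # sum of (p_pred / p_log) * count

    def add_example(self, p_log: float, reward: float, p_pred: float, count: int = 1) -> None:
        self.total_events += count
        if p_pred > 0:
            # Importance weight; assumes p_log > 0, as it is for any logged action.
            p_over_p = p_pred / p_log
            self.total_weight += p_over_p * count
            self.weighted_reward += reward * p_over_p * count

    def ips(self) -> float:
        # Normalize by the number of logged events.
        if self.total_events == 0:
            raise ValueError("no examples added")
        return self.weighted_reward / self.total_events

    def snips(self) -> float:
        # Normalize by the total importance weight instead.
        if self.total_weight == 0:
            raise ValueError("no examples with p_pred > 0")
        return self.weighted_reward / self.total_weight


est = IpsSnipsEstimator()
est.add_example(p_log=0.5, reward=1.0, p_pred=1.0)   # weight 2.0
est.add_example(p_log=0.25, reward=0.0, p_pred=0.5)  # weight 2.0, no reward
est.add_example(p_log=0.8, reward=1.0, p_pred=0.0)   # counted, but zero weight
print(est.ips(), est.snips())  # ips = 2/3, snips = 0.5
```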

src/test/python/mmlsparktest/vw/test_vw_cb.py (resolved)
src/test/python/mmlsparktest/vw/test_vw_cb.py (outdated, resolved)
@jackgerrits
Member Author

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@jackgerrits
Member Author

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@jackgerrits
Member Author

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@jackgerrits
Member Author

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@jackgerrits
Member Author

@eisber Considering this only really supports `cb_explore_adf`, should I update the class naming to reflect that? Does it seem like an issue that it is scoped to that?
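For context on what scoping to VW's `--cb_explore_adf` reduction implies: it handles a variable set of actions, each described by its own features, and its prediction is a probability distribution over those actions rather than a single choice. The chosen action's probability is what gets logged (the `probability` column in the tests) for later off-policy evaluation. The sketch below shows epsilon-greedy exploration, one of the strategies that reduction supports, and then samples an action. It is a minimal illustration, not VW's implementation: the function names are hypothetical, and treating scores as predicted costs (lower is better) is an assumption of this sketch.

```python
import random
from typing import List, Sequence, Tuple


def epsilon_greedy_pmf(scores: Sequence[float], epsilon: float) -> List[float]:
    """Spread `epsilon` uniformly over all actions and give the remaining
    1 - epsilon to the action with the lowest predicted cost."""
    if not scores:
        raise ValueError("need at least one action")
    if not 0.0 <= epsilon <= 1.0:
        raise ValueError("epsilon must be in [0, 1]")
    k = len(scores)
    best = min(range(k), key=lambda i: scores[i])
    pmf = [epsilon / k] * k
    pmf[best] += 1.0 - epsilon
    return pmf


def sample_action(pmf: Sequence[float], rng: random.Random) -> Tuple[int, float]:
    """Draw an action index from the pmf and return it together with its
    probability, which is the value that would be logged."""
    index = rng.choices(range(len(pmf)), weights=pmf, k=1)[0]
    return index, pmf[index]


pmf = epsilon_greedy_pmf([0.3, -1.0, 0.5], epsilon=0.2)
action, prob = sample_action(pmf, random.Random(0))
print(pmf, action, prob)
```

With three actions and `epsilon=0.2`, each action gets 0.2/3 and the lowest-cost action (index 1) additionally gets 0.8.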

@jackgerrits
Member Author

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

Collaborator

mhamilton723 left a comment

Great work! A few minor nits on the tests and I think it's ready to roll!

```python
class VowpalWabbitSpec(unittest.TestCase):
    def get_data(self):
        # create sample data
        schema = StructType([
```
Collaborator

You can just pass the column names if you're okay with the standard schema inference.

Member Author

It seems that, by default, schema inference makes whole numbers Longs, whereas the estimator expects Integers. I could extend the internal conversions to convert Long to Int, but then a value too large for an Int would fail with a conversion exception rather than a schema validation failure. It seems better to restrict the schema.

Collaborator

no worries!
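The trade-off in this thread: PySpark's schema inference maps Python `int` values to the 64-bit `LongType`, while the estimator expects the 32-bit `IntegerType`. An internal Long-to-Int conversion only works while every value fits in 32 bits; anything outside that range can only be rejected at conversion time, deep inside training, instead of up front when the DataFrame's schema is checked. The hypothetical helper below shows what such a checked narrowing has to do; it is an illustration of the failure mode, not code from this PR.

```python
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1


def narrow_long_to_int(value: int) -> int:
    """Hypothetical checked Long -> Int conversion: return the value unchanged
    if it fits in a signed 32-bit integer, otherwise raise OverflowError."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"{value} does not fit in a 32-bit Int")
    return value


print(narrow_long_to_int(3))          # 3
print(narrow_long_to_int(INT32_MAX))  # 2147483647
```

Declaring an explicit `IntegerType` schema, as the test does, moves this check to DataFrame creation, where the error is easier to attribute.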

```python
    StructField("probability", DoubleType())
])

data = pyspark.sql.SparkSession.builder.getOrCreate().createDataFrame([
```
Collaborator

`spark` should already be available to you because of the imports.

Member Author

Fixed

```python
    StructField("probability", DoubleType())
])

data = pyspark.sql.SparkSession.builder.getOrCreate().createDataFrame([
```
Collaborator

likewise here


```python
def get_data_two_shared(self):
    # create sample data
    schema = StructType([
```
Collaborator

same here

```scala
  }
}

class VerifyVowpalWabbitContextualBanditFuzzing extends EstimatorFuzzing[VowpalWabbitContextualBandit] {
```
Collaborator

This doesn't need to be a separate class from your contextual bandit tests.

Member Author

Fixed

@mhamilton723
Collaborator

/azp run

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

mhamilton723 merged commit e9d8802 into microsoft:master on Sep 3, 2020
Labels
None yet
Projects
None yet

5 participants