-
-
Notifications
You must be signed in to change notification settings - Fork 255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Multi-Task ElasticNet support #238
Conversation
This reverts commit d992c2c.
Running the new multi-task example gives the following output:
The variance looks pretty high, but I'm not sure if that's an issue. |
Codecov ReportBase: 38.68% // Head: 38.59% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #238 +/- ##
==========================================
- Coverage 38.68% 38.59% -0.10%
==========================================
Files 93 93
Lines 6087 6223 +136
==========================================
+ Hits 2355 2402 +47
- Misses 3732 3821 +89
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 can you point me to the test with high variance?
penalty: F, | ||
) -> (Array2<F>, F, u32) { | ||
let n_samples = F::cast(x.shape()[0]); | ||
let n_features = x.shape()[1]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using ncols
/ nrows
is a bit more expressive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like the current naming better expresses the properties of the dataset (# of rows = # of samples, # of cols = # of features). Plus this naming convention is used basically everywhere in this crate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nah I meant the method name, not the variable
#[derive(Clone, Debug, PartialEq)] | ||
pub struct ElasticNetValidParams<F> { | ||
pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should the multi task flag not be derived from the dataset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should be able to encapsulate multi-task use case in another Fit
impl on the same type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just tried to do that and it didn't work. The single and multitarget Fit
impls bound the target data type with AsSingleTarget
and AsMultiTarget
respectively. If I put both impls on a unified param type then I get a conflicting impl error, even though AsSingleTarget
and AsMultiTarget
are implemented on completely different types. This idea would be doable if we made the target types Array1
and Array2
instead of generics bounded by traits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the conflicting bounds are probably introduced here: https://github.com/rust-ml/linfa/blob/master/src/dataset/impl_targets.rs#L25-L26 have you also tried bounding the type with T: AsTargets<Ix = Ix?>
directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tried it. Still the same error. Rust compiler probably isn't smart enough to realize that those two bounds are completely disjoint.
It's not a test, but the new example I added |
reviewed the example and made two changes to make the explained variance more usable
--- a/algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
+++ b/algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
@@ -3,7 +3,7 @@ use linfa_elasticnet::{MultiTaskElasticNet, Result};
fn main() -> Result<()> {
// load Diabetes dataset
- let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.90);
+ let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80);
// train pure LASSO model with 0.1 penalty
let model = MultiTaskElasticNet::params()
@@ -18,7 +18,7 @@ fn main() -> Result<()> {
// validate
let y_est = model.predict(&valid);
- println!("predicted variance: {}", valid.r2(&y_est)?);
+ println!("predicted variance: {}", y_est.r2(&valid)?);
Ok(())
} which gives
so worse than taking the average, but the dataset is really small 😅 |
Continuation of #194