Skip to content

Commit

Permalink
bug fixes and further work on logistic regression
Browse files Browse the repository at this point in the history
  • Loading branch information
zenogantner committed May 25, 2011
1 parent f4604cb commit 7aa0e98
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 23 deletions.
2 changes: 1 addition & 1 deletion example_data
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
3 1:1 2:2
2 1:2 2:3
7 1:4 2:1
1 1:5 2:5
1 1:5 2:5
2 changes: 1 addition & 1 deletion linear_regression.pl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
else {
my $num_test_instances = (dims $test_instances)[0];

my $test_rss = sum(($test_pred - $targets) ** 2);
my $test_rss = sum(($test_pred - $test_targets) ** 2);
my $test_rmse = sqrt($test_rss / $num_test_instances);
say "RMSE $test_rmse N $num_test_instances";
}
Expand Down
50 changes: 29 additions & 21 deletions logistic_regression.pl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#!/usr/bin/perl

# Machine learning examples
# Logistic regression example

# Get example datasets for regression and classification with
# Get example dataset with
# wget http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/heart_scale

# (c) 2011 Zeno Gantner
# License: GPL

# TODO:
# - create evaluation and prediction subroutines
# - handle arbitrary two-class and multi-class problems
# - shrinkage
# - --verbose
# - SVM
# - internal CV
# - move shared code into module
# - load/save model

use strict;
use warnings;
Expand All @@ -33,11 +35,11 @@
'training-file=s' => \(my $training_file = ''),
'test-file=s' => \(my $test_file = ''),
'prediction-file=s' => \(my $prediction_file = ''),
'probabilities' => \(my $probabilities = 0),
) or usage(-1);

usage(0) if $help;

#$shrinkage += 0.0; # workaround for PDL or Getopt::Long bug (?)

if ($training_file eq '') {
say "Please give --training-file=FILE";
Expand All @@ -48,40 +50,41 @@

my $params = irls($instances, $targets);

# compute RSS and RMSE
# TODO compute accuracy
# compute accuracy
if ($compute_fit) {
my $num_instances = (dims $instances)[0];

my $pred = $params->transpose x $instances; # parentheses or OO notation are important here
my $rss = sum(($pred - $targets) ** 2);
my $rmse = sqrt($rss / $num_instances);
say "RSS $rss FIT_RMSE $rmse N $num_instances";
my $prob = 1 / (1 + exp(-1 * ($params->transpose x $instances) ));
my $pred = $prob > 0.5;

my $fit_err = sum(abs($pred - $targets));
$fit_err /= $num_instances;

say "FIT_ERR $fit_err N $num_instances";
}

# test/write out predictions
# TODO write out decisions
if ($test_file) {
my ( $test_instances, $test_targets ) = convert_to_pdl(read_data($test_file));
my $test_pred = $params->transpose x $test_instances;
my $test_prob = 1 / (1 + exp(-1 * ($params->transpose x $test_instances) ));
my $test_pred = $test_prob > 0.5;

if ($prediction_file) {
write_vector($test_pred, $prediction_file);
write_vector($probabilities ? $test_prob : $test_pred, $prediction_file);
}
else {
my $num_test_instances = (dims $test_instances)[0];

my $test_rss = sum(($test_pred - $targets) ** 2);
my $test_rmse = sqrt($test_rss / $num_test_instances);
say "RMSE $test_rmse N $num_test_instances";
my $test_err = sum(abs($test_pred - $test_targets));
$test_err /= $num_test_instances;
say "ERR $test_err N $num_test_instances";
}
}

exit 0;

# compute logistic regression parameters using iteratively reweighted least squares (IRLS)
sub irls {
# TODO add regularization
my ($instances, $targets) = @_;

my $num_instances = (dims $instances)[0];
Expand Down Expand Up @@ -153,7 +156,10 @@ sub read_data {
chomp $line;

my @tokens = split /\s+/, $line;
my $label = shift @tokens;
my $label = shift @tokens;
$label = 0 if $label == -1;

die "Label must be 1/0/-1, but is $label\n" if $label != 0 && $label != 1;

my %feature_value = map { split /:/ } @tokens;
$num_features = List::Util::max(keys %feature_value, $num_features);
Expand Down Expand Up @@ -182,15 +188,17 @@ sub usage {
print << "END";
$PROGRAM_NAME
Perl Data Language ridge regression example
Perl Data Language logistic regression example
usage: $PROGRAM_NAME [OPTIONS] [INPUT]
--help display this usage information
--compute-fit compute RSS and RMSE on training data
--epsilon=NUM set convergence sensitivity to NUM
--compute-fit compute error on training data
--training-file=FILE read training data from FILE
--test-file=FILE evaluate on FILE
--prediction-file=FILE write predictions for instances in the test file to FILE
--probabilties write out probabilties instead of decisions
END
exit $return_code;
}

0 comments on commit 7aa0e98

Please sign in to comment.