Skip to content

Commit

Permalink
switch to kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
yangarbiter committed Jan 16, 2017
1 parent 6f53edb commit fbdd3b3
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions examples/alce_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import os

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit
import sklearn.datasets
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR

# libact classes
from libact.base.dataset import Dataset, import_libsvm_sparse
from libact.models import LogisticRegression
from libact.models import SVM, LogisticRegression
from libact.query_strategies.multiclass import ActiveLearningWithCostEmbedding as ALCE
from libact.query_strategies import UncertaintySampling, RandomSampling
from libact.labelers import IdealLabeler
Expand Down Expand Up @@ -75,34 +76,34 @@ def main():
cost_matrix = np.random.RandomState(1126).rand(n_classes, n_classes)
np.fill_diagonal(cost_matrix, 0)

quota = 500 # number of samples to query
quota = 300 # number of samples to query

# Comparing UncertaintySampling strategy with RandomSampling.
# model is the base learner, e.g. LogisticRegression, SVM ... etc.
qs = UncertaintySampling(trn_ds, method='lc', model=LogisticRegression())
model = LogisticRegression()
qs = UncertaintySampling(trn_ds, method='lc', model=SVM())
model = SVM()
E_in_1, E_out_1 = run(trn_ds, tst_ds, lbr, model, qs, quota, cost_matrix)

qs2 = RandomSampling(trn_ds2)
model = LogisticRegression()
model = SVM()
E_in_2, E_out_2 = run(trn_ds2, tst_ds, lbr, model, qs2, quota, cost_matrix)

qs3 = ALCE(trn_ds3, cost_matrix, LinearRegression())
model = LogisticRegression()
qs3 = ALCE(trn_ds3, cost_matrix, SVR())
model = SVM()
E_in_3, E_out_3 = run(trn_ds3, tst_ds, lbr, model, qs3, quota, cost_matrix)

print("Uncertainty: ", E_out_1[::20].tolist())
print("Random: ", E_out_2[::20].tolist())
print("ALCE: ", E_out_3[::20].tolist())

query_num = np.arange(1, quota + 1)
plt.plot(query_num, E_out_1[0], 'g', label='Uncertainty sampling')
plt.plot(query_num, E_out_2[1], 'k', label='Random')
plt.plot(query_num, E_out_3[2], 'r', label='ALCE')
plt.plot(query_num, E_out_1, 'g', label='Uncertainty sampling')
plt.plot(query_num, E_out_2, 'k', label='Random')
plt.plot(query_num, E_out_3, 'r', label='ALCE')
plt.xlabel('Number of Queries')
plt.ylabel('Error')
plt.title('Experiment Result')
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.5),
fancybox=True, shadow=True, ncol=5)
plt.show()

Expand Down

0 comments on commit fbdd3b3

Please sign in to comment.