---
description: A machine learning approach that uses unlabeled data
---
PostgresML supports several clustering algorithms for unsupervised learning. Models can be trained using `pgml.train` on unlabeled data to identify groups within the data.
To build clusters on a given dataset, we can use a table or a view. Since clustering is an unsupervised algorithm, we don't need a column that represents a label as one of the inputs to `pgml.train`.
In `pgml.train`, you need to set `cluster` as the `task` and pass a `project_name`. Most parameters are optional.
```sql
pgml.train(
    project_name TEXT,
    task TEXT DEFAULT NULL,
    relation_name TEXT DEFAULT NULL,
    algorithm TEXT DEFAULT 'linear',
    hyperparams JSONB DEFAULT '{}'::JSONB
)
```
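The parameters above can also be passed by name. The sketch below trains on an ordinary multi-column table instead of an array column. The table `public.points`, its columns `x` and `y`, and the project name `'Point Clusters'` are hypothetical, and it assumes that every column of the relation is used as a feature and that `pgml.predict` expects the features in the same order as the table's columns.

```sql
-- Hypothetical table: three loose groups of 2-D points and no label column.
CREATE TABLE public.points AS
SELECT ((i % 3) * 10 + random())::REAL AS x,
       ((i % 3) * 10 + random())::REAL AS y
FROM generate_series(1, 300) AS i;

-- Named arguments matching the signature above.
-- Assumption: all columns of the relation (x, y) are used as features.
SELECT * FROM pgml.train(
    project_name  => 'Point Clusters',
    task          => 'cluster',
    relation_name => 'public.points',
    algorithm     => 'kmeans',
    hyperparams   => '{"n_clusters": 3}'
);

-- Assign a new point to a cluster; features are passed in column order (x, y).
SELECT pgml.predict('Point Clusters', ARRAY[10.2, 10.7]::REAL[]) AS cluster_id;
```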
| Algorithm | Reference |
|---|---|
| `affinity_propagation` | AffinityPropagation |
| `birch` | Birch |
| `kmeans` | K-Means |
| `mini_batch_kmeans` | MiniBatchKMeans |
This example trains models on the sklearn digits dataset, which is a copy of the test set of the UCI ML hand-written digits datasets. It demonstrates using a table with a single array feature column for clustering. You could do something similar with a vector column.
```sql
SELECT pgml.load_dataset('digits');

-- create an unlabeled table of the images for unsupervised learning
CREATE VIEW pgml.digit_vectors AS
SELECT image FROM pgml.digits;

-- view the dataset
SELECT left(image::text, 40) || ',...}' FROM pgml.digit_vectors LIMIT 10;

-- train a simple model to cluster the data
SELECT * FROM pgml.train('Handwritten Digit Clusters', 'cluster', 'pgml.digit_vectors', hyperparams => '{"n_clusters": 10}');

-- check out the predictions: the prediction is a cluster id,
-- which need not match the digit in the target column
SELECT target, pgml.predict('Handwritten Digit Clusters', image) AS prediction
FROM pgml.digits
LIMIT 10;
```
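Because cluster ids are arbitrary, a quick way to judge the result is to count how the known digit labels are distributed across clusters. This is a sketch that reuses only `pgml.digits` and `pgml.predict` from the example above; the `target` label is used here purely for inspection and was never a training input.

```sql
-- Cross-tabulate assigned cluster ids against the true digit labels.
-- A well-separated clustering puts most rows of each cluster under a single digit.
SELECT pgml.predict('Handwritten Digit Clusters', image) AS cluster_id,
       target AS digit,
       count(*) AS n
FROM pgml.digits
GROUP BY cluster_id, digit
ORDER BY cluster_id, n DESC;
```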
Once the project exists, later calls to `pgml.train` can omit the task and relation and only name a different algorithm:

```sql
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'affinity_propagation');
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'birch', hyperparams => '{"n_clusters": 10}');
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'kmeans', hyperparams => '{"n_clusters": 10}');
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'mini_batch_kmeans', hyperparams => '{"n_clusters": 10}');
```
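Beyond `n_clusters`, other settings can be tuned through `hyperparams`. This sketch assumes the JSON keys are forwarded unchanged to the matching scikit-learn estimator from the table above; the keys used (`damping`, `n_init`, `max_iter`, `batch_size`) are scikit-learn constructor parameters, and the values are illustrative rather than recommended.

```sql
-- Assumption: hyperparams keys are passed through to the scikit-learn estimator.
-- AffinityPropagation: damping must lie in [0.5, 1.0).
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'affinity_propagation', hyperparams => '{"damping": 0.9}');

-- KMeans: more restarts and iterations per run.
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'kmeans', hyperparams => '{"n_clusters": 10, "n_init": 10, "max_iter": 500}');

-- MiniBatchKMeans: size of each mini-batch.
SELECT * FROM pgml.train('Handwritten Digit Clusters', algorithm => 'mini_batch_kmeans', hyperparams => '{"n_clusters": 10, "batch_size": 256}');
```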