Skip to content

Commit

Permalink
Added support for hashes to Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Nov 13, 2024
1 parent 0891638 commit 9d350ba
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.4.0 (unreleased)

- Added support for hashes and Rover data frames to `predict` method
- Added support for hashes to `Dataset`
- Changed `Dataset` to use column names for feature names with Rover and Daru
- Changed `predict` method to match feature names with Daru
- Dropped support for Ruby < 3.1
Expand Down
9 changes: 9 additions & 0 deletions lib/lightgbm/dataset.rb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def group=(group)
end

def feature_name=(feature_names)
feature_names = feature_names.map(&:to_s)
@feature_names = feature_names
c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size)
# keep reference to string pointers
Expand Down Expand Up @@ -154,6 +155,14 @@ def construct
end
data = data.to_numo
nrow, ncol = data.shape
elsif data.is_a?(Array) && data.first.is_a?(Hash)
keys = data.first.keys
if @feature_name == "auto"
@feature_name = keys
end
nrow = data.count
ncol = data.first.count
flat_data = data.flat_map { |v| v.fetch_values(*keys) }
else
nrow = data.count
ncol = data.first.count
Expand Down
16 changes: 16 additions & 0 deletions test/dataset_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def test_dump_text
assert File.exist?(tempfile)
end

def test_hashes_string_keys
data = [{"x0" => 1, "x1" => 2}, {"x0" => 3, "x1" => 4}, {"x0" => 5, "x1" => 6}]
dataset = LightGBM::Dataset.new(data)
assert_equal 3, dataset.num_data
assert_equal 2, dataset.num_feature
assert_equal ["x0", "x1"], dataset.feature_name
end

def test_hashes_symbol_keys
data = [{x0: 1, x1: 2}, {x0: 3, x1: 4}, {x0: 5, x1: 6}]
dataset = LightGBM::Dataset.new(data)
assert_equal 3, dataset.num_data
assert_equal 2, dataset.num_feature
assert_equal ["x0", "x1"], dataset.feature_name
end

def test_matrix
data = Matrix.build(3, 3) { |row, col| row + col }
label = Vector.elements([4, 5, 6])
Expand Down

0 comments on commit 9d350ba

Please sign in to comment.