Skip to content

Commit 2fedeb5

Browse files
authored
Making changes in InternalValueCount and InternalAvg to support Scripted metric in reducing aggregation (#18411)
* Making changes in InternalValueCount and InternalAvg to support Scripted metric in reducing agggregation Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Adding tests for supporting Scripted Aggregation in reduce for average and value count aggregations Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Fixing build issues Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Addressed comments around logging and tests Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Making changes to CHANGELOG.md Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Adding a new class ScriptedAvg for reduce Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Adding changes to CHANGELOG Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Adding javadocs Signed-off-by: Kshitij Tandon <tandonks@amazon.com> * Fixing violations Signed-off-by: Kshitij Tandon <tandonks@amazon.com> --------- Signed-off-by: Kshitij Tandon <tandonks@amazon.com>
1 parent cc20ac7 commit 2fedeb5

File tree

6 files changed

+298
-4
lines changed

6 files changed

+298
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5151
- Optimize Composite Aggregations by removing unnecessary object allocations ([#18531](https://github.com/opensearch-project/OpenSearch/pull/18531))
5252
- [Star-Tree] Add search support for ip field type ([#18671](https://github.com/opensearch-project/OpenSearch/pull/18671))
5353
- [Derived Source] Add integration of derived source feature across various paths like get/search/recovery ([#18565](https://github.com/opensearch-project/OpenSearch/pull/18565))
54+
- Supporting Scripted Metric Aggregation when reducing aggregations in InternalValueCount and InternalAvg ([18411](https://github.com/opensearch-project/OpenSearch/pull18411)))
5455

5556
### Changed
5657
- Update Subject interface to use CheckedRunnable ([#18570](https://github.com/opensearch-project/OpenSearch/issues/18570))

server/src/main/java/org/opensearch/search/aggregations/metrics/InternalAvg.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,28 @@ public InternalAvg reduce(List<InternalAggregation> aggregations, ReduceContext
109109
// Compute the sum of double values with Kahan summation algorithm which is more
110110
// accurate than naive summation.
111111
for (InternalAggregation aggregation : aggregations) {
112-
InternalAvg avg = (InternalAvg) aggregation;
113-
count += avg.count;
114-
kahanSummation.add(avg.sum);
112+
if (aggregation instanceof InternalScriptedMetric) {
113+
// If using InternalScriptedMetric in place of InternalAvg
114+
Object value = ((InternalScriptedMetric) aggregation).aggregation();
115+
if (value instanceof ScriptedAvg scriptedAvg) {
116+
count += scriptedAvg.getCount();
117+
kahanSummation.add(scriptedAvg.getSum());
118+
} else {
119+
throw new IllegalArgumentException(
120+
"Invalid ScriptedMetric result for ["
121+
+ getName()
122+
+ "] avg aggregation. Expected ScriptedAvg "
123+
+ "but received ["
124+
+ (value == null ? "null" : value.getClass().getName())
125+
+ "]"
126+
);
127+
}
128+
} else {
129+
// Original handling for InternalAvg
130+
InternalAvg avg = (InternalAvg) aggregation;
131+
count += avg.count;
132+
kahanSummation.add(avg.sum);
133+
}
115134
}
116135
return new InternalAvg(getName(), kahanSummation.value(), count, format, getMetadata());
117136
}

server/src/main/java/org/opensearch/search/aggregations/metrics/InternalValueCount.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,24 @@ public double value() {
8686
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
8787
long valueCount = 0;
8888
for (InternalAggregation aggregation : aggregations) {
89-
valueCount += ((InternalValueCount) aggregation).value;
89+
if (aggregation instanceof InternalScriptedMetric) {
90+
// If using InternalScriptedMetric in place of InternalValueCount
91+
Object value = ((InternalScriptedMetric) aggregation).aggregation();
92+
if (value instanceof Number) {
93+
valueCount += ((Number) value).longValue();
94+
} else {
95+
throw new IllegalArgumentException(
96+
"Invalid ScriptedMetric result for ["
97+
+ getName()
98+
+ "] valueCount aggregation. Expected numeric value from ScriptedMetric aggregation but got ["
99+
+ (value == null ? "null" : value.getClass().getName())
100+
+ "]"
101+
);
102+
}
103+
} else {
104+
// Original handling for InternalValueCount
105+
valueCount += ((InternalValueCount) aggregation).value;
106+
}
90107
}
91108
return new InternalValueCount(name, valueCount, getMetadata());
92109
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
/*
10+
* Licensed to Elasticsearch under one or more contributor
11+
* license agreements. See the NOTICE file distributed with
12+
* this work for additional information regarding copyright
13+
* ownership. Elasticsearch licenses this file to you under
14+
* the Apache License, Version 2.0 (the "License"); you may
15+
* not use this file except in compliance with the License.
16+
* You may obtain a copy of the License at
17+
*
18+
* http://www.apache.org/licenses/LICENSE-2.0
19+
*
20+
* Unless required by applicable law or agreed to in writing,
21+
* software distributed under the License is distributed on an
22+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
23+
* KIND, either express or implied. See the License for the
24+
* specific language governing permissions and limitations
25+
* under the License.
26+
*/
27+
28+
/*
29+
* Modifications Copyright OpenSearch Contributors. See
30+
* GitHub history for details.
31+
*/
32+
33+
package org.opensearch.search.aggregations.metrics;
34+
35+
import org.opensearch.core.common.io.stream.StreamInput;
36+
import org.opensearch.core.common.io.stream.StreamOutput;
37+
import org.opensearch.core.common.io.stream.Writeable;
38+
39+
import java.io.IOException;
40+
41+
/**
42+
* Represents a scripted average calculation containing a sum and count.
43+
*
44+
* @opensearch.internal
45+
*/
46+
public class ScriptedAvg implements Writeable {
47+
private double sum;
48+
private long count;
49+
50+
/**
51+
* Constructor for ScriptedAvg
52+
*
53+
* @param sum The sum of values
54+
* @param count The count of values
55+
*/
56+
public ScriptedAvg(double sum, long count) {
57+
this.sum = sum;
58+
this.count = count;
59+
}
60+
61+
/**
62+
* Read from a stream.
63+
*/
64+
public ScriptedAvg(StreamInput in) throws IOException {
65+
this.sum = in.readDouble();
66+
this.count = in.readLong();
67+
}
68+
69+
@Override
70+
public void writeTo(StreamOutput out) throws IOException {
71+
out.writeDouble(sum);
72+
out.writeLong(count);
73+
}
74+
75+
public double getSum() {
76+
return sum;
77+
}
78+
79+
public long getCount() {
80+
return count;
81+
}
82+
}

server/src/test/java/org/opensearch/search/aggregations/metrics/InternalAvgTests.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
import java.util.List;
4343
import java.util.Map;
4444

45+
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.when;
47+
4548
public class InternalAvgTests extends InternalAggregationTestCase<InternalAvg> {
4649

4750
@Override
@@ -113,6 +116,102 @@ protected void assertFromXContent(InternalAvg avg, ParsedAggregation parsedAggre
113116
}
114117
}
115118

119+
public void testReduceWithScriptedMetric() {
120+
String name = "test_scripted_metric";
121+
DocValueFormat formatter = randomNumericDocValueFormat();
122+
List<InternalAggregation> aggregations = new ArrayList<>();
123+
124+
// Add regular InternalAvg
125+
aggregations.add(new InternalAvg(name, 50.0, 10L, formatter, null));
126+
127+
// Add ScriptedMetric with ScriptedAvg object
128+
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
129+
when(scriptedMetric1.getName()).thenReturn(name);
130+
ScriptedAvg scriptedAvg = new ScriptedAvg(100.0, 20L);
131+
when(scriptedMetric1.aggregation()).thenReturn(scriptedAvg);
132+
aggregations.add(scriptedMetric1);
133+
134+
InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
135+
InternalAvg reduced = avg.reduce(aggregations, null);
136+
137+
// Expected values:
138+
// From InternalAvg: sum=50.0, count=10
139+
// From scriptedMetric1: sum=100.0, count=20
140+
// Total: sum=150.0, count=30
141+
assertEquals(30L, reduced.getCount());
142+
assertEquals(150.0, reduced.getSum(), 0.0000001);
143+
assertEquals(5.0, reduced.getValue(), 0.0000001); // 150/30
144+
}
145+
146+
public void testReduceWithInternalAvgAggregation() {
147+
String name = "test_avg";
148+
DocValueFormat formatter = randomNumericDocValueFormat();
149+
List<InternalAggregation> aggregations = new ArrayList<>();
150+
151+
// Add multiple InternalAvg aggregations
152+
aggregations.add(new InternalAvg(name, 50.0, 10L, formatter, null));
153+
aggregations.add(new InternalAvg(name, 100.0, 20L, formatter, null));
154+
aggregations.add(new InternalAvg(name, 150.0, 30L, formatter, null));
155+
156+
InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
157+
InternalAvg reduced = avg.reduce(aggregations, null);
158+
159+
// Expected values:
160+
// sum = 50.0 + 100.0 + 150.0 = 300.0
161+
// count = 10 + 20 + 30 = 60
162+
assertEquals(60L, reduced.getCount());
163+
assertEquals(300.0, reduced.getSum(), 0.0000001);
164+
assertEquals(5.0, reduced.getValue(), 0.0000001); // 300/60
165+
}
166+
167+
public void testReduceWithScriptedMetricInvalidType() {
168+
String name = "test_scripted_metric";
169+
DocValueFormat formatter = randomNumericDocValueFormat();
170+
List<InternalAggregation> aggregations = new ArrayList<>();
171+
172+
// Add regular InternalAvg
173+
aggregations.add(new InternalAvg(name, 50.0, 10L, formatter, null));
174+
175+
// Add ScriptedMetric with invalid return type (String instead of double[])
176+
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
177+
when(scriptedMetric1.getName()).thenReturn(name);
178+
when(scriptedMetric1.aggregation()).thenReturn("invalid_type");
179+
aggregations.add(scriptedMetric1);
180+
181+
InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
182+
183+
// Expect an IllegalArgumentException when reducing with invalid type
184+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> avg.reduce(aggregations, null));
185+
assertEquals(
186+
"Invalid ScriptedMetric result for [test_scripted_metric] avg aggregation. Expected ScriptedAvg but received [java.lang.String]",
187+
e.getMessage()
188+
);
189+
}
190+
191+
public void testReduceWithScriptedMetricInvalidArrayLength() {
192+
String name = "test_scripted_metric";
193+
DocValueFormat formatter = randomNumericDocValueFormat();
194+
List<InternalAggregation> aggregations = new ArrayList<>();
195+
196+
// Add regular InternalAvg
197+
aggregations.add(new InternalAvg(name, 50.0, 10L, formatter, null));
198+
199+
// Add ScriptedMetric with double array of wrong length (should be 2)
200+
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
201+
when(scriptedMetric.getName()).thenReturn(name);
202+
when(scriptedMetric.aggregation()).thenReturn(new double[] { 100.0, 20.0, 30.0 }); // length 3 instead of 2
203+
aggregations.add(scriptedMetric);
204+
205+
InternalAvg avg = new InternalAvg(name, 0.0, 0L, formatter, null);
206+
207+
// Expect an IllegalArgumentException when reducing with invalid array length
208+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> avg.reduce(aggregations, null));
209+
assertEquals(
210+
"Invalid ScriptedMetric result for [test_scripted_metric] avg aggregation. Expected ScriptedAvg but received [[D]",
211+
e.getMessage()
212+
);
213+
}
214+
116215
@Override
117216
protected InternalAvg mutateInstance(InternalAvg instance) {
118217
String name = instance.getName();

server/src/test/java/org/opensearch/search/aggregations/metrics/InternalValueCountTests.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,18 @@
3232

3333
package org.opensearch.search.aggregations.metrics;
3434

35+
import org.opensearch.search.aggregations.InternalAggregation;
3536
import org.opensearch.search.aggregations.ParsedAggregation;
3637
import org.opensearch.test.InternalAggregationTestCase;
3738

39+
import java.util.ArrayList;
3840
import java.util.HashMap;
3941
import java.util.List;
4042
import java.util.Map;
4143

44+
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.when;
46+
4247
public class InternalValueCountTests extends InternalAggregationTestCase<InternalValueCount> {
4348

4449
@Override
@@ -57,6 +62,77 @@ protected void assertFromXContent(InternalValueCount valueCount, ParsedAggregati
5762
assertEquals(valueCount.getValueAsString(), ((ParsedValueCount) parsedAggregation).getValueAsString());
5863
}
5964

65+
public void testReduceWithScriptedMetric() {
66+
String name = "test_scripted_metric";
67+
List<InternalAggregation> aggregations = new ArrayList<>();
68+
69+
// Add regular InternalValueCount
70+
aggregations.add(new InternalValueCount(name, 50L, null));
71+
72+
// Add ScriptedMetric with Long value
73+
InternalScriptedMetric scriptedMetric1 = mock(InternalScriptedMetric.class);
74+
when(scriptedMetric1.aggregation()).thenReturn(20L);
75+
aggregations.add(scriptedMetric1);
76+
77+
// Add ScriptedMetric with Integer value
78+
InternalScriptedMetric scriptedMetric2 = mock(InternalScriptedMetric.class);
79+
when(scriptedMetric2.aggregation()).thenReturn(30);
80+
aggregations.add(scriptedMetric2);
81+
82+
// Add ScriptedMetric with Double value
83+
InternalScriptedMetric scriptedMetric3 = mock(InternalScriptedMetric.class);
84+
when(scriptedMetric3.aggregation()).thenReturn(10.5);
85+
aggregations.add(scriptedMetric3);
86+
87+
InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
88+
InternalValueCount reduced = (InternalValueCount) valueCount.reduce(aggregations, null);
89+
90+
// Expected: 50 + 20 + 30 + 10 = 110
91+
assertEquals(110L, reduced.getValue());
92+
}
93+
94+
public void testReduceWithInternalValueCountOnly() {
95+
String name = "test_value_count";
96+
List<InternalAggregation> aggregations = new ArrayList<>();
97+
98+
// Add multiple InternalValueCount aggregations
99+
aggregations.add(new InternalValueCount(name, 50L, null));
100+
aggregations.add(new InternalValueCount(name, 30L, null));
101+
aggregations.add(new InternalValueCount(name, 20L, null));
102+
103+
InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
104+
InternalValueCount reduced = (InternalValueCount) valueCount.reduce(aggregations, null);
105+
106+
// Expected: 50 + 30 + 20 = 100
107+
assertEquals(100L, reduced.getValue());
108+
}
109+
110+
public void testReduceWithScriptedMetricInvalidValue() {
111+
String name = "test_scripted_metric";
112+
List<InternalAggregation> aggregations = new ArrayList<>();
113+
114+
// Add regular InternalValueCount
115+
aggregations.add(new InternalValueCount(name, 50L, null));
116+
117+
// Add ScriptedMetric with invalid value type (String instead of Number)
118+
InternalScriptedMetric scriptedMetric = mock(InternalScriptedMetric.class);
119+
when(scriptedMetric.aggregation()).thenReturn("invalid_value");
120+
aggregations.add(scriptedMetric);
121+
122+
InternalValueCount valueCount = new InternalValueCount(name, 0L, null);
123+
124+
// Expect an IllegalArgumentException when reducing with invalid value type
125+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> valueCount.reduce(aggregations, null));
126+
127+
assertEquals(
128+
"Invalid ScriptedMetric result for ["
129+
+ name
130+
+ "] valueCount aggregation. "
131+
+ "Expected numeric value from ScriptedMetric aggregation but got [java.lang.String]",
132+
e.getMessage()
133+
);
134+
}
135+
60136
@Override
61137
protected InternalValueCount mutateInstance(InternalValueCount instance) {
62138
String name = instance.getName();

0 commit comments

Comments
 (0)