Skip to content
This repository has been archived by the owner on Jan 19, 2022. It is now read-only.

Spanner struct pojo query arg #826

Merged
merged 3 commits into from
Jul 9, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ private Mutation saveObject(Op op, Object object, Set<String> includeColumns) {
.getPersistentEntity(object.getClass());
Mutation.WriteBuilder writeBuilder = writeBuilder(op,
persistentEntity.tableName());
this.spannerEntityProcessor.write(object, writeBuilder, includeColumns);
this.spannerEntityProcessor.write(object, writeBuilder::set, includeColumns);
return writeBuilder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import com.google.cloud.spanner.ReadOnlyTransaction;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Struct.Builder;
import com.google.cloud.spanner.TimestampBound;
import com.google.cloud.spanner.TransactionContext;
import com.google.cloud.spanner.TransactionRunner.TransactionCallable;
Expand Down Expand Up @@ -139,7 +141,11 @@ public <T> List<T> query(Class<T> entityClass, String sql, List<String> tags,
}
return this.spannerEntityProcessor.mapToList(
executeQuery(SpannerStatementQueryExecutor
.buildStatementFromSqlWithArgs(finalSql, tags, params), options),
.buildStatementFromSqlWithArgs(finalSql, tags, param -> {
Builder builder = Struct.newBuilder();
this.spannerEntityProcessor.write(param, builder::set);
return builder.build();
}, params), options),
entityClass, Optional.empty(), allowPartialRead);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Set;

import com.google.cloud.spanner.Key;
import com.google.cloud.spanner.Mutation.WriteBuilder;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.Struct;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -176,15 +175,16 @@ private boolean canHandlePropertyTypeForArrayWrite(Class type,
/**
* Writes each of the source properties to the sink.
* @param source entity to be written
* @param sink the stateful {@link WriteBuilder} as a target for writing.
* @param sink the stateful multiple-value-binder as a target for writing.
*/
@Override
public void write(Object source, WriteBuilder sink) {
public void write(Object source, MultipleValueBinder sink) {
this.entityWriter.write(source, sink);
}

@Override
public void write(Object source, WriteBuilder sink, Set<String> includeColumns) {
public void write(Object source, MultipleValueBinder sink,
Set<String> includeColumns) {
this.entityWriter.write(source, sink, includeColumns);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Key;
import com.google.cloud.spanner.Mutation.WriteBuilder;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.ValueBinder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -102,6 +103,7 @@ private static Map<Class<?>, BiConsumer<ValueBinder<?>, Iterable>> createIterabl
(BiFunction<ValueBinder, boolean[], ?>) ValueBinder::toBoolArray);
builder.put(long[].class,
(BiFunction<ValueBinder, long[], ?>) ValueBinder::toInt64Array);
builder.put(Struct.class, (BiFunction<ValueBinder, Struct, ?>) ValueBinder::to);

singleItemType2ToMethodMap = builder.build();
}
Expand All @@ -117,7 +119,7 @@ private static Map<Class<?>, BiConsumer<ValueBinder<?>, Iterable>> createIterabl
}

@Override
public void write(Object source, WriteBuilder sink) {
public void write(Object source, MultipleValueBinder sink) {
write(source, sink, null);
}

Expand All @@ -128,7 +130,8 @@ public void write(Object source, WriteBuilder sink) {
* @param includeColumns the properties/columns to write. If null, then all columns are
* written.
*/
public void write(Object source, WriteBuilder sink, Set<String> includeColumns) {
public void write(Object source, MultipleValueBinder sink,
Set<String> includeColumns) {
boolean writeAllColumns = includeColumns == null;
SpannerPersistentEntity<?> persistentEntity = this.spannerMappingContext
.getPersistentEntity(source.getClass());
Expand Down Expand Up @@ -221,7 +224,8 @@ private boolean isValidSpannerKeyType(Class type) {
*/
// @formatter:on
@SuppressWarnings("unchecked")
private void writeProperty(WriteBuilder sink, PersistentPropertyAccessor accessor,
private void writeProperty(MultipleValueBinder sink,
PersistentPropertyAccessor accessor,
SpannerPersistentProperty property) {
Object propertyValue = accessor.getProperty(property);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2018 original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.cloud.gcp.data.spanner.core.convert;

import com.google.cloud.spanner.ValueBinder;

/**
* An interface that allows multiple values to be bound for Cloud Spanner.
*
* @author Chengyuan Zhao
*/
public interface MultipleValueBinder {

/**
* Returns a {@link ValueBinder} for a given field name to bind.
* @param fieldName the name of the field to bind.
* @return a value-binder object that then accepts the value to bind.
*/
ValueBinder set(String fieldName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
import java.util.Set;

import com.google.cloud.spanner.Key;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.Mutation.WriteBuilder;

import org.springframework.data.convert.EntityWriter;

/**
* @author Chengyuan Zhao
* @author Balint Pato
*/
public interface SpannerEntityWriter extends EntityWriter<Object, Mutation.WriteBuilder> {
public interface SpannerEntityWriter extends EntityWriter<Object, MultipleValueBinder> {

/**
* Writes an object's properties to the sink.
Expand All @@ -37,7 +35,7 @@ public interface SpannerEntityWriter extends EntityWriter<Object, Mutation.Write
* @param includeColumns the properties/columns to write. If null, then all columns are
* written.
*/
void write(Object source, WriteBuilder sink, Set<String> includeColumns);
void write(Object source, MultipleValueBinder sink, Set<String> includeColumns);

Key writeToKey(Object key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.function.Function;

import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.ValueBinder;

import org.springframework.cloud.gcp.data.spanner.core.SpannerOperations;
Expand Down Expand Up @@ -66,13 +67,16 @@ public static <T> List<T> executeQuery(Class<T> type, PartTree tree, Object[] pa
Pair<String, List<String>> sqlAndTags = buildPartTreeSqlString(tree,
spannerMappingContext, type);
return spannerOperations.query(type, buildStatementFromSqlWithArgs(
sqlAndTags.getFirst(), sqlAndTags.getSecond(), params));
sqlAndTags.getFirst(), sqlAndTags.getSecond(), null, params));
}

/**
* Creates a Spanner statement.
* @param sql the SQL string with tags.
* @param tags the tags that appear in the SQL string.
* @param paramStructConvertFunc a function to use to convert params to {@link Struct}
* objects if they cannot be directly mapped to Cloud Spanner supported param types.
* If null then this last-attempt conversion is skipped.
* @param params the parameters to substitute the tags. The ordering must be the same
* as the tags.
* @return an SQL statement ready to use with Spanner.
Expand All @@ -81,7 +85,7 @@ public static <T> List<T> executeQuery(Class<T> type, PartTree tree, Object[] pa
*/
@SuppressWarnings("unchecked")
public static Statement buildStatementFromSqlWithArgs(String sql, List<String> tags,
Object[] params) {
Function<Object, Struct> paramStructConvertFunc, Object[] params) {
if (tags == null && params == null) {
return Statement.of(sql);
}
Expand All @@ -95,11 +99,26 @@ public static Statement buildStatementFromSqlWithArgs(String sql, List<String> t
// @formatter:off
BiFunction<ValueBinder, Object, ?> toMethod = (BiFunction<ValueBinder, Object, ?>)
ConverterAwareMappingSpannerEntityWriter.singleItemType2ToMethodMap
.get(param.getClass());
.get(Struct.class.isAssignableFrom(param.getClass()) ? Struct.class : param.getClass());
// @formatter:on
if (toMethod == null) {
throw new IllegalArgumentException("Param: " + param.toString()
+ " is not a supported type: " + param.getClass());
// try to convert the param object into a Struct
String errorMessage = "Param: " + param.toString()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to create this String object only if it will be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

+ " is not a supported type: " + param.getClass();
if (paramStructConvertFunc == null) {
throw new IllegalArgumentException(errorMessage);
}
try {
// @formatter:off
toMethod = (BiFunction<ValueBinder, Object, ?>)
ConverterAwareMappingSpannerEntityWriter.singleItemType2ToMethodMap
.get(Struct.class);
// @formatter:on
param = paramStructConvertFunc.apply(param);
}
catch (SpannerDataException e) {
throw new IllegalArgumentException(errorMessage, e);
}
}
builder = (Statement.Builder) toMethod.apply(builder.bind(tags.get(i)),
param);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public void writeTest() {
when(bytesFieldBinder.to((ByteArray) any())).thenReturn(null);
when(writeBuilder.set(eq("bytes"))).thenReturn(bytesFieldBinder);

this.spannerEntityWriter.write(t, writeBuilder);
this.spannerEntityWriter.write(t, writeBuilder::set);

verify(idBinder, times(1)).to(eq(t.id));
verify(stringFieldBinder, times(1)).to(eq(t.stringField));
Expand Down Expand Up @@ -233,7 +233,7 @@ public void writeSomeColumnsTest() throws ClassNotFoundException {
when(booleanFieldBinder.to((Boolean) any())).thenReturn(null);
when(writeBuilder.set(eq("booleanField"))).thenReturn(booleanFieldBinder);

this.spannerEntityWriter.write(t, writeBuilder,
this.spannerEntityWriter.write(t, writeBuilder::set,
new HashSet<>(Arrays.asList("id", "custom_col")));

verify(idBinder, times(1)).to(eq(t.id));
Expand All @@ -246,15 +246,15 @@ public void writeUnsupportedTypeIterableTest() {
FaultyTestEntity2 ft = new FaultyTestEntity2();
ft.listWithUnsupportedInnerType = new ArrayList<TestEntity>();
WriteBuilder writeBuilder = Mutation.newInsertBuilder("faulty_test_table_2");
this.spannerEntityWriter.write(ft, writeBuilder);
this.spannerEntityWriter.write(ft, writeBuilder::set);
}

@Test(expected = SpannerDataException.class)
public void writeIncompatibleTypeTest() {
FaultyTestEntity ft = new FaultyTestEntity();
ft.fieldWithUnsupportedType = new TestEntity();
WriteBuilder writeBuilder = Mutation.newInsertBuilder("faulty_test_table");
this.spannerEntityWriter.write(ft, writeBuilder);
this.spannerEntityWriter.write(ft, writeBuilder::set);
}

@Test(expected = IllegalArgumentException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import java.util.ArrayList;
import java.util.List;

import com.google.cloud.spanner.Struct;
import org.junit.Test;
import org.junit.runner.RunWith;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gcp.data.spanner.test.AbstractSpannerIntegrationTest;
import org.springframework.cloud.gcp.data.spanner.test.domain.SymbolAction;
import org.springframework.cloud.gcp.data.spanner.test.domain.Trade;
import org.springframework.cloud.gcp.data.spanner.test.domain.TradeProjection;
import org.springframework.cloud.gcp.data.spanner.test.domain.TradeRepository;
Expand Down Expand Up @@ -138,6 +140,13 @@ public void declarativeQueryMethodTest() {
this.tradeRepository.findBySymbolContains("BCD")
.forEach(x -> assertEquals("ABCD", x.getSymbol()));
assertTrue(this.tradeRepository.findBySymbolNotContains("BCD").isEmpty());

assertEquals(3, this.tradeRepository
.findBySymbolAndActionPojo(new SymbolAction("ABCD", "BUY")).size());
assertEquals(3,
this.tradeRepository.findBySymbolAndActionStruct(Struct.newBuilder()
.set("symbol").to("ABCD").set("action").to("BUY").build())
.size());
}

protected List<Trade> insertTrades(String traderId1, String action, int numTrades) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.Value;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -51,6 +52,7 @@
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand All @@ -75,11 +77,14 @@ public class SqlSpannerQueryTests {

private final Pageable pageable = PageRequest.of(3, 10, this.sort);

private final SpannerEntityProcessor spannerEntityProcessor = mock(
SpannerEntityProcessor.class);

@Before
public void initMocks() {
this.queryMethod = mock(QueryMethod.class);
this.spannerTemplate = spy(new SpannerTemplate(mock(DatabaseClient.class),
new SpannerMappingContext(), mock(SpannerEntityProcessor.class),
new SpannerMappingContext(), this.spannerEntityProcessor,
mock(SpannerMutationFactory.class)));
this.expressionParser = new SpelExpressionParser();
this.evaluationContextProvider = mock(EvaluationContextProvider.class);
Expand All @@ -102,23 +107,27 @@ public void compoundNameConventionTest() {
+ "price<>#{#tag4 * -1} AND " + "( action=@tag0 AND ticker=@tag1 ) OR "
+ "( trader_id=@tag2 AND price<@tag3 ) OR ( price>=@tag4 AND id<>NULL AND "
+ "trader_id=NULL AND trader_id LIKE %@tag5 AND price=TRUE AND price=FALSE AND "
+ "struct_val = @tag8 AND struct_val = @tag9 "
+ "price>@tag6 AND price<=@tag7 )ORDER BY id DESC LIMIT 3;";

String entityResolvedSql = "SELECT * FROM (SELECT DISTINCT * FROM " + "trades@{index=fakeindex}"
+ " WHERE price=@SpELtag1 AND price<>@SpELtag1 OR price<>@SpELtag2 AND "
+ "( action=@tag0 AND ticker=@tag1 ) OR "
+ "( trader_id=@tag2 AND price<@tag3 ) OR ( price>=@tag4 AND id<>NULL AND "
+ "trader_id=NULL AND trader_id LIKE %@tag5 AND price=TRUE AND price=FALSE AND "
+ "struct_val = @tag8 AND struct_val = @tag9 "
+ "price>@tag6 AND price<=@tag7 )ORDER BY id DESC LIMIT 3) "
+ "ORDER BY COLA ASC , COLB DESC LIMIT 10 OFFSET 30";

Object[] params = new Object[] { "BUY", this.pageable, "abcd", "abc123", 8.88,
3.33, "blahblah",
1.11, 2.22, };
1.11, 2.22, Struct.newBuilder().set("symbol").to("ABCD").set("action")
.to("BUY").build(),
new SymbolAction("ABCD", "BUY") };

String[] paramNames = new String[] { "tag0", "ignoredPageable", "tag1", "tag2",
"tag3", "tag4",
"tag5", "tag6", "tag7" };
"tag5", "tag6", "tag7", "tag8", "tag9" };

Parameters parameters = mock(Parameters.class);

Expand Down Expand Up @@ -167,6 +176,8 @@ public void compoundNameConventionTest() {
assertEquals(params[6], paramMap.get("tag5").getString());
assertEquals(params[7], paramMap.get("tag6").getFloat64());
assertEquals(params[8], paramMap.get("tag7").getFloat64());
assertEquals(params[9], paramMap.get("tag8").getStruct());
verify(this.spannerEntityProcessor, times(1)).write(same(params[10]), any());
assertEquals(-8.88, paramMap.get("SpELtag1").getFloat64(), 0.00001);
assertEquals(-3.33, paramMap.get("SpELtag2").getFloat64(), 0.00001);

Expand Down Expand Up @@ -226,6 +237,17 @@ public void mutlipleSortTest() {
sqlSpannerQuery.execute(new Object[] { this.sort, this.sort });
}

private static class SymbolAction {
String symbol;

String action;

SymbolAction(String s, String a) {
this.symbol = s;
this.action = a;
}
}

@Table(name = "trades")
private static class Trade {
@PrimaryKey
Expand Down
Loading