Skip to content

Commit

Permalink
Adding the ability to specify column names to include during insert
Browse files Browse the repository at this point in the history
  • Loading branch information
souravroy committed Feb 4, 2024
1 parent 54100df commit a353b39
Show file tree
Hide file tree
Showing 5 changed files with 514 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.springframework.data.util.Pair;
import org.springframework.web.bind.annotation.RestController;

import java.util.List;
import java.util.Map;

@RestController
Expand All @@ -19,11 +20,12 @@ public class CreateController implements CreateRestApi {
@Override
public CreateResponse save(String tableName,
String schemaName,
List<String> columns,
Map<String, Object> data,
String tsid,
String tsidType) {

Pair<Integer, Object> result = createService.save(schemaName, tableName, data, tsid, tsidType);
Pair<Integer, Object> result = createService.save(schemaName, tableName, columns, data, tsid, tsidType);

return new CreateResponse(result.getFirst(), result.getSecond());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.*;

import java.util.List;
import java.util.Map;

public interface CreateRestApi {
@ResponseStatus(HttpStatus.CREATED)
@PostMapping("/{tableName}")
CreateResponse save(@PathVariable String tableName,
@RequestHeader(name = "Content-Profile") String schemaName,
@RequestParam(name = "columns", required = false) List<String> columns,
@RequestBody Map<String, Object> data,
@RequestParam(name = "tsid", required = false) String tsid,
@RequestParam(name = "tsidType", required = false, defaultValue = "number") String tsidType);
Expand Down
167 changes: 86 additions & 81 deletions src/main/java/com/homihq/db2rest/rest/create/CreateService.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package com.homihq.db2rest.rest.create;

import com.homihq.db2rest.config.Db2RestConfigProperties;
import com.homihq.db2rest.exception.GenericDataAccessException;
import com.homihq.db2rest.mybatis.DB2RestRenderingStrategy;
import io.hypersistence.tsid.TSID;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.mybatis.dynamic.sql.SqlColumn;
import org.mybatis.dynamic.sql.SqlTable;
import org.mybatis.dynamic.sql.insert.BatchInsertDSL;
import org.mybatis.dynamic.sql.insert.GeneralInsertDSL;
Expand All @@ -33,64 +33,69 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.mybatis.dynamic.sql.insert.BatchInsertDSL.insert;
import static org.mybatis.dynamic.sql.insert.GeneralInsertDSL.insertInto;
import static org.springframework.util.CollectionUtils.isEmpty;


@Service
@Slf4j
@RequiredArgsConstructor
public class CreateService {

private final Db2RestConfigProperties db2RestConfigProperties;
private final NamedParameterJdbcTemplate namedParameterJdbcTemplate;
private final DB2RestRenderingStrategy db2RestRenderingStrategy = new DB2RestRenderingStrategy();
private final String DEFAULT_GENERATED_KEY_NAME = "GENERATED_KEY";
private static final String DEFAULT_GENERATED_KEY_NAME = "GENERATED_KEY";

@Transactional
public Pair<Integer, Object> save(String schemaName, String tableName, Map<String, Object> data, String tsid, String tsidType) {
//db2RestConfigProperties.verifySchema(schemaName);
@Transactional
public Pair<Integer, Object> save(String schemaName, String tableName, List<String> columns, Map<String, Object> data, String tsid, String tsidType) {
//db2RestConfigProperties.verifySchema(schemaName);

processTSID(data, tsid, tsidType);

SqlTable table = SqlTable.of(tableName);
GeneralInsertDSL generalInsertDSL = insertInto(table);

for(String key : data.keySet()) {
generalInsertDSL.set(table.column(key)).toValue(data.get(key));
if (!isEmpty(columns)) {
columns.forEach(key -> setColumnsAndValues(generalInsertDSL, table.column(key), data.get(key)));
} else {
data.keySet().forEach(key -> setColumnsAndValues(generalInsertDSL, table.column(key), data.get(key)));
}

GeneralInsertStatementProvider insertStatement = generalInsertDSL.build().render(RenderingStrategies.SPRING_NAMED_PARAMETER);

log.debug("SQL - {}", insertStatement.getInsertStatement());
log.debug("SQL - row - {}", insertStatement.getParameters());

int row;
int row;

KeyHolder keyHolder = new GeneratedKeyHolder();
try {
row = namedParameterJdbcTemplate.update(insertStatement.getInsertStatement(),
new MapSqlParameterSource(insertStatement.getParameters()),
keyHolder
);
} catch (DataAccessException e) {
throw new GenericDataAccessException(e.getMostSpecificCause().getMessage());
}
KeyHolder keyHolder = new GeneratedKeyHolder();
try {
row = namedParameterJdbcTemplate.update(insertStatement.getInsertStatement(),
new MapSqlParameterSource(insertStatement.getParameters()),
keyHolder
);
} catch (DataAccessException e) {
throw new GenericDataAccessException(e.getMostSpecificCause().getMessage());
}

String primaryColName = getPrimaryKeyColName(tableName);
Object generated_key = extractGeneratedKeys(List.of(data), keyHolder.getKeyList(), tsid, primaryColName).getFirst();

String primaryColName = getPrimaryKeyColName(tableName);
Object generated_key = extractGeneratedKeys(List.of(data), keyHolder.getKeyList(), tsid, primaryColName).getFirst();
log.debug("Inserted - {} row(s), generated key - {}", row, generated_key);

log.debug("Inserted - {} row(s), generated key - {}", row, generated_key);
return Pair.of(row, Objects.requireNonNull(generated_key));
}

return Pair.of(row, Objects.requireNonNull(generated_key));
}
private static <T> void setColumnsAndValues(GeneralInsertDSL generalInsertDSL, SqlColumn<T> column, T value) {
generalInsertDSL.set(column).toValue(value);
}

private void processTSID(Map<String, Object> data, String tsid, String tsidType) {
//1. check if tsid column is specified if yes go ahead add or update it with generated TSID value

if(StringUtils.isNotBlank(tsid)) {
if (StringUtils.isNotBlank(tsid)) {
data.put(tsid, getTSIDValue(tsidType));
}

Expand All @@ -99,10 +104,10 @@ private void processTSID(Map<String, Object> data, String tsid, String tsidType)
}

private Object getTSIDValue(String tsidType) {
if(StringUtils.equalsAnyIgnoreCase(tsidType, "number", "string")) {
if (StringUtils.equalsAnyIgnoreCase(tsidType, "number", "string")) {
return
StringUtils.equalsIgnoreCase(tsidType, "number") ? TSID.Factory.getTsid().toLong() :
TSID.Factory.getTsid().toString();
StringUtils.equalsIgnoreCase(tsidType, "number") ? TSID.Factory.getTsid().toLong() :
TSID.Factory.getTsid().toString();


}
Expand All @@ -111,28 +116,28 @@ private Object getTSIDValue(String tsidType) {

}

@Transactional
public Pair<int[], List<Object>> saveBulk(String schemaName, String tableName, List<Map<String, Object>> dataList, String tsid, String tsidType) {
if (Objects.isNull(dataList) || dataList.isEmpty()) throw new RuntimeException("No data provided");
@Transactional
public Pair<int[], List<Object>> saveBulk(String schemaName, String tableName, List<Map<String, Object>> dataList, String tsid, String tsidType) {
if (Objects.isNull(dataList) || dataList.isEmpty()) throw new RuntimeException("No data provided");

for(Map<String, Object> data : dataList)
for (Map<String, Object> data : dataList)
processTSID(data, tsid, tsidType);

SqlTable table = SqlTable.of(tableName);

BatchInsertDSL<Map<String, Object>> batchInsertDSL = insert(dataList)
.into(table);

Map<String, Object> item = dataList.getFirst();
Map<String, Object> item = dataList.getFirst();

for(String key : item.keySet()) {
for (String key : item.keySet()) {
batchInsertDSL.map(table.column(key)).toProperty(key);
}

BatchInsert<Map<String,Object>> batchInsert =
BatchInsert<Map<String, Object>> batchInsert =
batchInsertDSL
.build()
.render(db2RestRenderingStrategy);
.build()
.render(db2RestRenderingStrategy);

SqlParameterSource[] batch = SqlParameterSourceUtils.createBatch(dataList.toArray());

Expand All @@ -141,64 +146,64 @@ public Pair<int[], List<Object>> saveBulk(String schemaName, String tableName, L
log.debug("batch -> {}", batchInsert.getRecords());

int[] updateCounts;
KeyHolder keyHolder = new GeneratedKeyHolder();
KeyHolder keyHolder = new GeneratedKeyHolder();

String primaryColName = getPrimaryKeyColName(tableName);
String primaryColName = getPrimaryKeyColName(tableName);

try {
updateCounts = namedParameterJdbcTemplate.batchUpdate(batchInsert.getInsertStatementSQL(), batch, keyHolder);
} catch (DataAccessException e) {
e.printStackTrace();
throw new GenericDataAccessException(e.getMostSpecificCause().getMessage());
}
try {
updateCounts = namedParameterJdbcTemplate.batchUpdate(batchInsert.getInsertStatementSQL(), batch, keyHolder);
} catch (DataAccessException e) {
throw new GenericDataAccessException(e.getMostSpecificCause().getMessage());
}

log.debug("Update counts - {}", updateCounts.length);

List<Object> generatedKeys = extractGeneratedKeys(dataList, keyHolder.getKeyList(), tsid, primaryColName);
List<Object> generatedKeys = extractGeneratedKeys(dataList, keyHolder.getKeyList(), tsid, primaryColName);

return Pair.of(updateCounts, generatedKeys);
}

return Pair.of(updateCounts, generatedKeys);
}
private List<Object> extractGeneratedKeys(final List<Map<String, Object>> dataList, final List<Map<String, Object>> keyList, final String tsid, final String primaryColName) {
String keyName;
List<Map<String, Object>> keySource;

private List<Object> extractGeneratedKeys(final List<Map<String, Object>> dataList, final List<Map<String, Object>> keyList, final String tsid, final String primaryColName) {
String keyName;
List<Map<String, Object>> keySource;
if (StringUtils.isNotBlank(tsid)) {
keyName = tsid;
keySource = dataList;
} else {
keySource = keyList;

if (StringUtils.isNotBlank(tsid)) {
keyName = tsid;
keySource = dataList;
} else {
keySource = keyList;
boolean usingDefaultGeneratedKey = keyList.stream()
.anyMatch(map -> map.containsKey(DEFAULT_GENERATED_KEY_NAME));

boolean usingDefaultGeneratedKey = keyList.stream()
.anyMatch(map -> map.containsKey(DEFAULT_GENERATED_KEY_NAME));
keyName = usingDefaultGeneratedKey ? DEFAULT_GENERATED_KEY_NAME : primaryColName;
}

keyName = usingDefaultGeneratedKey ? DEFAULT_GENERATED_KEY_NAME : primaryColName;
}
return keySource.stream()
.map(data -> Optional.ofNullable(data.get(keyName)))
.flatMap(Optional::stream)
.toList();
}

return keySource.stream()
.map(data -> Optional.ofNullable(data.get(keyName)))
.flatMap(Optional::stream)
.collect(Collectors.toList());
}
private String getPrimaryKeyColName(final String tableName) {
DataSource dataSource = namedParameterJdbcTemplate.getJdbcTemplate().getDataSource();

private String getPrimaryKeyColName(final String tableName) {
DataSource dataSource = namedParameterJdbcTemplate.getJdbcTemplate().getDataSource();
if (dataSource == null) {
throw new IllegalStateException("DataSource is null");
}

if (dataSource == null) {
throw new IllegalStateException("DataSource is null");
}
try (Connection connection = dataSource.getConnection()) {
DatabaseMetaData databaseMetaData = connection.getMetaData();
ResultSet rs = databaseMetaData.getPrimaryKeys(null, null, tableName);

try (Connection connection = dataSource.getConnection()) {
DatabaseMetaData databaseMetaData = connection.getMetaData();
ResultSet rs = databaseMetaData.getPrimaryKeys(null, null, tableName);
if (rs.next()) {
return rs.getString("COLUMN_NAME");
}
} catch (SQLException e) {
throw new GenericDataAccessException("Error retrieving primary key column name for table " + tableName);
}

if (rs.next()) {
return rs.getString("COLUMN_NAME");
}
} catch (SQLException e) {
throw new GenericDataAccessException("Error retrieving primary key column name for table " + tableName);
}
return "";
}

return "";
}
}
Loading

0 comments on commit a353b39

Please sign in to comment.