From df5bfa6b80ba419042498b343aed697b7807b810 Mon Sep 17 00:00:00 2001 From: Terry Chow <32403408+tkyc@users.noreply.github.com> Date: Wed, 22 May 2024 10:48:20 -0700 Subject: [PATCH] Execute stored procedures directly for RPC calls (#2410) * RPC fix * Assertion index correction * Removed magic number * Removed login drop command * Code review changes * Formatting --- .../jdbc/SQLServerCallableStatement.java | 16 ++++++++++++---- .../jdbc/SQLServerPreparedStatement.java | 17 +---------------- .../sqlserver/jdbc/SQLServerStatement.java | 4 ++-- .../sqlserver/jdbc/StreamRetValue.java | 6 +++++- .../CallableStatementTest.java | 11 +++++++++++ 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerCallableStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerCallableStatement.java index f98585fe4..270df7ee7 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerCallableStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerCallableStatement.java @@ -216,7 +216,7 @@ && callRPCDirectly(inOutParam)) { } if (inOutParam[i - 1].isReturnValue() && bReturnValueSyntax && !isCursorable(executeMethod) && !isTVPType - && returnValueStatus != userDefinedFunctionReturnStatus) { + && returnValueStatus != USER_DEFINED_FUNCTION_RETURN_STATUS) { return inOutParam[i - 1]; } @@ -269,7 +269,6 @@ final void processOutParameters() throws SQLServerException { // the response stream. for (int index = 0; index < inOutParam.length; ++index) { if (index != outParamIndex && inOutParam[index].isValueGotten()) { - assert inOutParam[index].isOutput(); inOutParam[index].resetOutputValue(); } } @@ -365,7 +364,7 @@ boolean onRetValue(TDSReader tdsReader) throws SQLServerException { OutParamHandler outParamHandler = new OutParamHandler(); if (bReturnValueSyntax && (nOutParamsAssigned == 0) && !isCursorable(executeMethod) && !isTVPType - && callRPCDirectly(inOutParam) && returnValueStatus != userDefinedFunctionReturnStatus) { + && callRPCDirectly(inOutParam) && returnValueStatus != USER_DEFINED_FUNCTION_RETURN_STATUS) { nOutParamsAssigned++; } @@ -414,7 +413,7 @@ && callRPCDirectly(inOutParam) && returnValueStatus != userDefinedFunctionReturn outParamIndex = outParamHandler.srv.getOrdinalOrLength(); if (bReturnValueSyntax && !isCursorable(executeMethod) && !isTVPType && callRPCDirectly(inOutParam) - && returnValueStatus != userDefinedFunctionReturnStatus) { + && returnValueStatus != USER_DEFINED_FUNCTION_RETURN_STATUS) { outParamIndex++; } else { // Statements need to have their out param indices adjusted by the number @@ -424,10 +423,19 @@ && callRPCDirectly(inOutParam) && returnValueStatus != userDefinedFunctionReturn if ((outParamIndex < 0 || outParamIndex >= inOutParam.length) || (!inOutParam[outParamIndex].isOutput())) { + + // For RPC calls with out parameters, the initial return value token will indicate + // it being a RPC. In such case, consume the token as it does not contain the out parameter + // value. The subsequent token will have the value. + if (outParamHandler.srv.getStatus() == USER_DEFINED_FUNCTION_RETURN_STATUS) { + continue; + } + if (getStatementLogger().isLoggable(java.util.logging.Level.INFO)) { getStatementLogger().info(toString() + " Unexpected outParamIndex: " + outParamIndex + "; adjustment: " + outParamIndexAdjustment); } + connection.throwInvalidTDS(); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 5f1766d94..f3a441581 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -77,9 +77,6 @@ public class SQLServerPreparedStatement extends SQLServerStatement implements IS // flag whether is call escape syntax private boolean isCallEscapeSyntax; - // flag whether is four part syntax - private boolean isFourPartSyntax; - /** Parameter positions in processed SQL statement text. */ final int[] userSQLParamPositions; @@ -149,11 +146,6 @@ private void setPreparedStatementHandle(int handle) { */ private static final Pattern execEscapePattern = Pattern.compile("^\\s*(?i)(?:exec|execute)\\b"); - /** - * Regex for four part syntax - */ - private static final Pattern fourPartSyntaxPattern = Pattern.compile("(.+)\\.(.+)\\.(.+)\\.(.+)"); - /** Returns the prepared statement SQL */ @Override public String toString() { @@ -290,7 +282,6 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) { userSQL = parsedSQL.processedSQL; isExecEscapeSyntax = isExecEscapeSyntax(sql); isCallEscapeSyntax = isCallEscapeSyntax(sql); - isFourPartSyntax = isFourPartSyntax(sql); userSQLParamPositions = parsedSQL.parameterPositions; initParams(userSQLParamPositions.length); useBulkCopyForBatchInsert = conn.getUseBulkCopyForBatchInsert(); @@ -1258,10 +1249,8 @@ boolean callRPCDirectly(Parameter[] params) throws SQLServerException { // 4. Compliant CALL escape syntax // If isExecEscapeSyntax is true, EXEC escape syntax is used then use prior behaviour of // wrapping call to execute the procedure - // If isFourPartSyntax is true, sproc is being executed against linked server, then - // use prior behaviour of wrapping call to execute procedure return (null != procedureName && paramCount != 0 && !isTVPType(params) && isCallEscapeSyntax - && !isExecEscapeSyntax && !isFourPartSyntax); + && !isExecEscapeSyntax); } /** @@ -1289,10 +1278,6 @@ private boolean isCallEscapeSyntax(String sql) { return callEscapePattern.matcher(sql).find(); } - private boolean isFourPartSyntax(String sql) { - return fourPartSyntaxPattern.matcher(sql).find(); - } - /** * Executes sp_prepare to prepare a parameterized statement and sets the prepared statement handle * diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java index 12432c560..f27e1b1ac 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java @@ -72,7 +72,7 @@ public class SQLServerStatement implements ISQLServerStatement { /** Check if statement contains TVP Type */ boolean isTVPType = false; - static int userDefinedFunctionReturnStatus = 2; + protected static final int USER_DEFINED_FUNCTION_RETURN_STATUS = 2; final boolean getIsResponseBufferingAdaptive() { return isResponseBufferingAdaptive; @@ -1676,7 +1676,7 @@ boolean onRetValue(TDSReader tdsReader) throws SQLServerException { // in which case we need to stop parsing and let CallableStatement take over. // A RETVALUE token appearing in the execution results, but before any RETSTATUS // token, is a TEXTPTR return value that should be ignored. - if (moreResults && null == procedureRetStatToken && status != userDefinedFunctionReturnStatus) { + if (moreResults && null == procedureRetStatToken && status != USER_DEFINED_FUNCTION_RETURN_STATUS) { Parameter p = new Parameter( Util.shouldHonorAEForParameters(stmtColumnEncriptionSetting, connection)); p.skipRetValStatus(tdsReader); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/StreamRetValue.java b/src/main/java/com/microsoft/sqlserver/jdbc/StreamRetValue.java index a4324b6a8..82aa3c639 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/StreamRetValue.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/StreamRetValue.java @@ -18,10 +18,14 @@ final class StreamRetValue extends StreamPacket { */ private int ordinalOrLength; - final int getOrdinalOrLength() { + int getOrdinalOrLength() { return ordinalOrLength; } + int getStatus() { + return status; + } + /* * Status: 0x01 if the return value is an OUTPUT parameter of a stored procedure 0x02 if the return value is from a * User Defined Function diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java index 67dde1fad..a399e53f7 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/callablestatement/CallableStatementTest.java @@ -1187,6 +1187,8 @@ public void testFourPartSyntaxCallEscapeSyntax() throws SQLException { Statement stmt = linkedServerConnection.createStatement()) { stmt.execute( "create or alter procedure dbo.TestAdd(@Num1 int, @Num2 int, @Result int output) as begin set @Result = @Num1 + @Num2; end;"); + + stmt.execute("create or alter procedure dbo.TestReturn(@Num1 int) as select @Num1 return @Num1*3 "); } try (CallableStatement cstmt = connection @@ -1211,6 +1213,15 @@ public void testFourPartSyntaxCallEscapeSyntax() throws SQLException { cstmt.execute(); assertEquals(sum, cstmt.getInt(3)); } + + try (CallableStatement cstmt = connection + .prepareCall("{? = call [" + linkedServer + "].master.dbo.TestReturn(?)}")) { + int expected = 15; + cstmt.registerOutParameter(1, java.sql.Types.INTEGER); + cstmt.setInt(2, 5); + cstmt.execute(); + assertEquals(expected, cstmt.getInt(1)); + } } /**