diff --git a/src/main/java/org/jruby/ext/openssl/Cipher.java b/src/main/java/org/jruby/ext/openssl/Cipher.java index f2caf282..47b57c25 100644 --- a/src/main/java/org/jruby/ext/openssl/Cipher.java +++ b/src/main/java/org/jruby/ext/openssl/Cipher.java @@ -68,6 +68,7 @@ import org.jruby.runtime.builtin.IRubyObject; import org.jruby.runtime.Visibility; import org.jruby.util.ByteList; +import org.jruby.util.TypeConverter; import static org.jruby.ext.openssl.OpenSSL.*; @@ -1104,6 +1105,11 @@ private String getCipherAlgorithm() { @JRubyMethod public IRubyObject update(final ThreadContext context, final IRubyObject arg) { + return update(context, arg, null); + } + + @JRubyMethod + public IRubyObject update(final ThreadContext context, final IRubyObject arg, IRubyObject buffer) { final Ruby runtime = context.runtime; if ( isDebug(runtime) ) dumpVars( runtime.getOut(), "update()" ); @@ -1143,7 +1149,13 @@ public IRubyObject update(final ThreadContext context, final IRubyObject arg) { debugStackTrace( runtime, e ); throw newCipherError(runtime, e); } - return RubyString.newString(runtime, str); + + if( buffer == null ) { + return RubyString.newString(runtime, str); + } else { + buffer = TypeConverter.convertToType(buffer, context.runtime.getString(), "to_str", true); + return ((RubyString) buffer).replace(RubyString.newString(runtime, str)); + } } @JRubyMethod(name = "<<") diff --git a/src/test/ruby/test_cipher.rb b/src/test/ruby/test_cipher.rb index 8b3cdc06..e08cd33a 100644 --- a/src/test/ruby/test_cipher.rb +++ b/src/test/ruby/test_cipher.rb @@ -475,4 +475,27 @@ def test_encrypt_aes_cfb_20_incompatibility end end + def test_encrypt_aes_256_cbc_modifies_buffer + cipher = OpenSSL::Cipher.new("AES-256-CBC") + cipher.key = "a" * 32 + cipher.encrypt + buffer = '' + actual = cipher.update('bar' * 10, buffer) + if jruby? + expected = "\xE6\xD3Y\fc\xEE\xBA\xB2*\x0Fr\xD1\xC2b\x03\xD0" + else + expected = "8\xA7\xBE\xB1\xAE\x88j\xCB\xA3\xE9j\x00\xD2W_\x91" + end + assert_equal actual, expected + assert_equal buffer, expected + end + + def test_encrypt_aes_256_cbc_invalid_buffer + cipher = OpenSSL::Cipher.new("AES-256-CBC") + cipher.key = "a" * 32 + cipher.encrypt + buffer = Object.new + assert_raise(TypeError) { cipher.update('bar' * 10, buffer) } + end + end