diff --git a/benchmark/check_until.yaml b/benchmark/check_until.yaml new file mode 100644 index 0000000000..659fee28b7 --- /dev/null +++ b/benchmark/check_until.yaml @@ -0,0 +1,15 @@ +prelude: |- + $LOAD_PATH.unshift(File.expand_path("lib")) + require "strscan" + scanner = StringScanner.new("test string") + str = "string" + reg = /string/ +benchmark: + regexp: | + scanner.check_until(/string/) + regexp_var: | + scanner.check_until(reg) + string: | + scanner.check_until("string") + string_var: | + scanner.check_until(str) diff --git a/ext/jruby/org/jruby/ext/strscan/RubyStringScanner.java b/ext/jruby/org/jruby/ext/strscan/RubyStringScanner.java index 54a9dfae39..db33717881 100644 --- a/ext/jruby/org/jruby/ext/strscan/RubyStringScanner.java +++ b/ext/jruby/org/jruby/ext/strscan/RubyStringScanner.java @@ -262,17 +262,6 @@ private IRubyObject extractBegLen(Ruby runtime, int beg, int len) { // MRI: strscan_do_scan private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succptr, boolean getstr, boolean headonly) { final Ruby runtime = context.runtime; - - if (headonly) { - if (!(regex instanceof RubyRegexp)) { - regex = regex.convertToString(); - } - } else { - if (!(regex instanceof RubyRegexp)) { - throw runtime.newTypeError("wrong argument type " + regex.getMetaClass() + " (expected Regexp)"); - } - } - check(context); ByteList strBL = str.getByteList(); @@ -310,9 +299,9 @@ private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succp } if (ret < 0) return context.nil; } else { - RubyString pattern = (RubyString) regex; + RubyString pattern = regex.convertToString(); - str.checkEncoding(pattern); + Encoding patternEnc = str.checkEncoding(pattern); if (restLen() < pattern.size()) { return context.nil; @@ -321,11 +310,18 @@ private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succp ByteList patternBL = pattern.getByteList(); int patternSize = patternBL.realSize(); - if (ByteList.memcmp(strBL.unsafeBytes(), strBeg + curr, patternBL.unsafeBytes(), patternBL.begin(), patternSize) != 0) { - return context.nil; + if (headonly) { + if (ByteList.memcmp(strBL.unsafeBytes(), strBeg + curr, patternBL.unsafeBytes(), patternBL.begin(), patternSize) != 0) { + return context.nil; + } + setRegisters(patternSize); + } else { + int pos = StringSupport.index(strBL, patternBL, strBeg + curr, patternEnc); + if (pos == -1) { + return context.nil; + } + setRegisters(patternSize + pos - curr); } - - setRegisters(patternSize); } setMatched(); diff --git a/ext/strscan/strscan.c b/ext/strscan/strscan.c index fad35925a8..45aa1e38e0 100644 --- a/ext/strscan/strscan.c +++ b/ext/strscan/strscan.c @@ -686,14 +686,6 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly { struct strscanner *p; - if (headonly) { - if (!RB_TYPE_P(pattern, T_REGEXP)) { - StringValue(pattern); - } - } - else { - Check_Type(pattern, T_REGEXP); - } GET_SCANNER(self, p); CLEAR_MATCH_STATUS(p); @@ -714,14 +706,25 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly } } else { + StringValue(pattern); rb_enc_check(p->str, pattern); if (S_RESTLEN(p) < RSTRING_LEN(pattern)) { return Qnil; } - if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) { - return Qnil; + + if (headonly) { + if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) { + return Qnil; + } + set_registers(p, RSTRING_LEN(pattern)); + } else { + long pos = rb_memsearch(RSTRING_PTR(pattern), RSTRING_LEN(pattern), + CURPTR(p), S_RESTLEN(p), rb_enc_get(pattern)); + if (pos == -1) { + return Qnil; + } + set_registers(p, RSTRING_LEN(pattern) + pos); } - set_registers(p, RSTRING_LEN(pattern)); } MATCHED(p); diff --git a/test/strscan/test_stringscanner.rb b/test/strscan/test_stringscanner.rb index 143cf7197d..9b7b7910d0 100644 --- a/test/strscan/test_stringscanner.rb +++ b/test/strscan/test_stringscanner.rb @@ -262,7 +262,7 @@ def test_concat end def test_scan - s = create_string_scanner('stra strb strc', true) + s = create_string_scanner("stra strb\0strc", true) tmp = s.scan(/\w+/) assert_equal 'stra', tmp @@ -270,7 +270,7 @@ def test_scan assert_equal ' ', tmp assert_equal 'strb', s.scan(/\w+/) - assert_equal ' ', s.scan(/\s+/) + assert_equal "\u0000", s.scan(/\0/) tmp = s.scan(/\w+/) assert_equal 'strc', tmp @@ -312,11 +312,14 @@ def test_scan end def test_scan_string - s = create_string_scanner('stra strb strc') + s = create_string_scanner("stra strb\0strc") assert_equal 'str', s.scan('str') assert_equal 'str', s[0] assert_equal 3, s.pos assert_equal 'a ', s.scan('a ') + assert_equal 'strb', s.scan('strb') + assert_equal "\u0000", s.scan("\0") + assert_equal 'strc', s.scan('strc') str = 'stra strb strc'.dup s = create_string_scanner(str, false) @@ -668,13 +671,47 @@ def test_exist_p assert_equal(nil, s.exist?(/e/)) end - def test_exist_p_string + def test_exist_p_invalid_argument s = create_string_scanner("test string") assert_raise(TypeError) do - s.exist?(" ") + s.exist?(1) end end + def test_exist_p_string + omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby" + s = create_string_scanner("test string") + assert_equal(3, s.exist?("s")) + assert_equal(0, s.pos) + s.scan("test") + assert_equal(2, s.exist?("s")) + assert_equal(4, s.pos) + assert_equal(nil, s.exist?("e")) + end + + def test_scan_until + s = create_string_scanner("Foo Bar\0Baz") + assert_equal("Foo", s.scan_until(/Foo/)) + assert_equal(3, s.pos) + assert_equal(" Bar", s.scan_until(/Bar/)) + assert_equal(7, s.pos) + assert_equal(nil, s.skip_until(/Qux/)) + assert_equal("\u0000Baz", s.scan_until(/Baz/)) + assert_equal(11, s.pos) + end + + def test_scan_until_string + omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby" + s = create_string_scanner("Foo Bar\0Baz") + assert_equal("Foo", s.scan_until("Foo")) + assert_equal(3, s.pos) + assert_equal(" Bar", s.scan_until("Bar")) + assert_equal(7, s.pos) + assert_equal(nil, s.skip_until("Qux")) + assert_equal("\u0000Baz", s.scan_until("Baz")) + assert_equal(11, s.pos) + end + def test_skip_until s = create_string_scanner("Foo Bar Baz") assert_equal(3, s.skip_until(/Foo/)) @@ -684,6 +721,16 @@ def test_skip_until assert_equal(nil, s.skip_until(/Qux/)) end + def test_skip_until_string + omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby" + s = create_string_scanner("Foo Bar Baz") + assert_equal(3, s.skip_until("Foo")) + assert_equal(3, s.pos) + assert_equal(4, s.skip_until("Bar")) + assert_equal(7, s.pos) + assert_equal(nil, s.skip_until("Qux")) + end + def test_check_until s = create_string_scanner("Foo Bar Baz") assert_equal("Foo", s.check_until(/Foo/)) @@ -693,6 +740,16 @@ def test_check_until assert_equal(nil, s.check_until(/Qux/)) end + def test_check_until_string + omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby" + s = create_string_scanner("Foo Bar Baz") + assert_equal("Foo", s.check_until("Foo")) + assert_equal(0, s.pos) + assert_equal("Foo Bar", s.check_until("Bar")) + assert_equal(0, s.pos) + assert_equal(nil, s.check_until("Qux")) + end + def test_search_full s = create_string_scanner("Foo Bar Baz") assert_equal(8, s.search_full(/Bar /, false, false)) @@ -705,6 +762,19 @@ def test_search_full assert_equal(11, s.pos) end + def test_search_full_string + omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby" + s = create_string_scanner("Foo Bar Baz") + assert_equal(8, s.search_full("Bar ", false, false)) + assert_equal(0, s.pos) + assert_equal("Foo Bar ", s.search_full("Bar ", false, true)) + assert_equal(0, s.pos) + assert_equal(8, s.search_full("Bar ", true, false)) + assert_equal(8, s.pos) + assert_equal("Baz", s.search_full("az", true, true)) + assert_equal(11, s.pos) + end + def test_peek s = create_string_scanner("test string") assert_equal("test st", s.peek(7))