diff --git a/src/xterm.rs b/src/xterm.rs index 7e77174..e33e804 100644 --- a/src/xterm.rs +++ b/src/xterm.rs @@ -13,6 +13,10 @@ const QUERY_FG: &[u8] = b"\x1b]10;?"; const FG_RESPONSE_PREFIX: &[u8] = b"\x1b]10;"; const QUERY_BG: &[u8] = b"\x1b]11;?"; const BG_RESPONSE_PREFIX: &[u8] = b"\x1b]11;"; +const ST: &[u8] = b"\x1b\\"; +const DA1: &[u8] = b"\x1b[c"; +const ESC: u8 = 0x1b; +const BEL: u8 = 0x07; pub(crate) fn foreground_color(options: QueryOptions) -> Result { let quirks = terminal_quirks_from_env(); @@ -68,11 +72,6 @@ fn map_timed_out_err(timeout: Duration) -> impl Fn(Error) -> Error { } } -const ST: &[u8] = b"\x1b\\"; -const DA1: &[u8] = b"\x1b[c"; -const ESC: u8 = 0x1b; -const BEL: u8 = 0x07; - fn parse_response(response: Vec, prefix: &[u8]) -> Result { response .strip_prefix(prefix) diff --git a/src/xterm/quirks.rs b/src/xterm/quirks.rs index 2765d9e..2132376 100644 --- a/src/xterm/quirks.rs +++ b/src/xterm/quirks.rs @@ -1,3 +1,4 @@ +use super::{BEL, ST}; use std::env; use std::io::{self, Write}; use std::sync::OnceLock; @@ -14,6 +15,7 @@ fn terminal_quirk_from_env_eager() -> TerminalQuirks { match env::var("TERM") { Ok(term) if term == "dumb" => Barebones, Ok(term) if term == "rxvt-unicode" || term.starts_with("rxvt-unicode-") => Urxvt, + Ok(term) if term == "screen" || term.starts_with("screen.") => Screen, Ok(_) | Err(_) => None, } } @@ -23,6 +25,7 @@ pub(super) enum TerminalQuirks { None, Barebones, Urxvt, + Screen, } impl TerminalQuirks { @@ -33,9 +36,6 @@ impl TerminalQuirks { } pub(super) fn string_terminator(&self) -> &[u8] { - const ST: &[u8] = b"\x1b\\"; - const BEL: u8 = 0x07; - if let TerminalQuirks::Urxvt = self { // The currently released version has a bug where it terminates the response with `ESC` instead of `ST`. // Fixed by revision [1.600](http://cvs.schmorp.de/rxvt-unicode/src/command.C?revision=1.600&view=markup). @@ -47,10 +47,119 @@ impl TerminalQuirks { } pub(super) fn write_all(&self, w: &mut dyn Write, bytes: &[u8]) -> io::Result<()> { - w.write_all(bytes) + if let TerminalQuirks::Screen = self { + screen::write_to_host_terminal(w, bytes) + } else { + w.write_all(bytes) + } } pub(super) fn write_string_terminator(&self, writer: &mut dyn Write) -> io::Result<()> { self.write_all(writer, self.string_terminator()) } } + +// Screen breaks one of our fundamental assumptions: +// It responds to `DA1` *before* responding to `OSC 10`. +// To work around this we wrap our query in a `DCS` / `ST` pair. +// +// This directs screen to send our query to the underlying terminal instead of +// interpreting our query itself. Hopefully the underlying terminal is more +// *sensible* about order... +mod screen { + use super::*; + use crate::xterm::{ESC, ST}; + use memchr::memchr_iter; + + /// From the [manual](https://www.gnu.org/software/screen/manual/html_node/Control-Sequences.html): + /// > Device Control String \ + /// > Outputs a string directly to the host + /// > terminal without interpretation. + const DCS: &[u8] = b"\x1bP"; + + pub(super) fn write_to_host_terminal(w: &mut dyn Write, mut bytes: &[u8]) -> io::Result<()> { + loop { + // If our query contains `ST` we need to split it across multiple + // `DCS` / `ST` pairs to avoid screen from interpreting our `ST` as + // the terminator for the `DCS` sequence. + if let Some(index) = find_st(bytes) { + write_to_host_terminal_unchecked(w, &bytes[..index])?; + write_to_host_terminal_unchecked(w, &[ESC])?; + write_to_host_terminal_unchecked(w, &[b'\\'])?; + bytes = &bytes[(index + ST.len())..]; + } else { + write_to_host_terminal_unchecked(w, bytes)?; + break; + } + } + + Ok(()) + } + + fn write_to_host_terminal_unchecked(w: &mut dyn Write, bytes: &[u8]) -> io::Result<()> { + if !bytes.is_empty() { + w.write_all(DCS)?; + w.write_all(bytes)?; + w.write_all(ST)?; + } + Ok(()) + } + + fn find_st(haystack: &[u8]) -> Option { + memchr_iter(ESC, haystack) + .filter_map(|index| { + let next_byte = *haystack.get(index + 1)?; + (next_byte == b'\\').then_some(index) + }) + .next() + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::fmt::CaretNotation; + + #[test] + fn wraps_query_between_dcs_and_st() { + let expected = b"\x1bP\x1b[c\x1b\\"; + let mut actual = Vec::new(); + write_to_host_terminal(&mut actual, b"\x1b[c").unwrap_or_else(|_| unreachable!()); + assert_eq!(to_string(expected.as_slice()), to_string(&actual)); + } + + #[test] + fn splits_st_among_multiple_dcs_and_st_pairs() { + let expected = b"\x1bP\x1b]11;?\x1b\\\x1bP\x1b\x1b\\\x1bP\\\x1b\\"; + let mut actual = Vec::new(); + write_to_host_terminal(&mut actual, b"\x1b]11;?\x1b\\") + .unwrap_or_else(|_| unreachable!()); + assert_eq!(to_string(expected.as_slice()), to_string(&actual)); + } + + #[test] + fn finds_st_at_start() { + assert_eq!(Some(0), find_st(ST)); + } + + #[test] + fn finds_st_after_esc() { + assert_eq!(Some(1), find_st(&[ESC, ESC, b'\\'])) + } + + #[test] + fn finds_first_esc() { + assert_eq!( + Some(3), + find_st(&[b'f', b'o', b'o', ESC, b'\\', ESC, b'\\']) + ) + } + + fn to_string(input: &[u8]) -> String { + use std::str::from_utf8; + format!( + "{}", + CaretNotation(from_utf8(input).expect("valid utf-8 data")) + ) + } + } +}