diff options
Diffstat (limited to 'rust-analyzer-chromiumos-wrapper/src/main.rs')
-rw-r--r-- | rust-analyzer-chromiumos-wrapper/src/main.rs | 364 |
1 files changed, 364 insertions, 0 deletions
diff --git a/rust-analyzer-chromiumos-wrapper/src/main.rs b/rust-analyzer-chromiumos-wrapper/src/main.rs new file mode 100644 index 00000000..f59af454 --- /dev/null +++ b/rust-analyzer-chromiumos-wrapper/src/main.rs @@ -0,0 +1,364 @@ +// Copyright 2022 The ChromiumOS Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::env; +use std::fs::File; +use std::io::{self, BufRead, BufReader, BufWriter, Write}; +use std::os::unix::process::CommandExt; +use std::path::{Path, PathBuf}; +use std::process::{self, Child}; +use std::str::from_utf8; +use std::thread; + +use anyhow::{anyhow, bail, Context, Result}; + +use log::trace; + +use simplelog::{Config, LevelFilter, WriteLogger}; + +use serde_json::{from_slice, to_writer, Value}; + +fn main() -> Result<()> { + let args = env::args().skip(1); + + let d = env::current_dir()?; + let chromiumos_root = match find_chromiumos_root(&d) { + Some(x) => x, + None => { + // It doesn't appear that we're in a chroot. Run the + // regular rust-analyzer. + return Err(process::Command::new("rust-analyzer").args(args).exec())?; + } + }; + + let args: Vec<String> = args.collect(); + if !args.is_empty() { + // We've received command line arguments, and there are 3 possibilities: + // * We just forward the arguments to rust-analyzer and exit. + // * We don't support the arguments, so we bail. + // * We still need to do our path translation in the LSP protocol. + fn run(args: &[String]) -> Result<()> { + return Err(process::Command::new("cros_sdk") + .args(["--", "rust-analyzer"]) + .args(args) + .exec())?; + } + + if args.iter().any(|x| match x.as_str() { + "--version" | "--help" | "-h" | "--print-config-schema" => true, + _ => false, + }) { + // With any of these options rust-analyzer will just print something and exit. + return run(&args); + } + + if !args[0].starts_with("-") { + // It's a subcommand, and seemingly none of these need the path translation + // rust-analyzer-chromiumos-wrapper provides. + return run(&args); + } + + if args.iter().any(|x| x == "--log-file") { + bail!("rust-analyzer-chromiums_wrapper doesn't support --log-file"); + } + + // Otherwise it seems we're probably OK to proceed. + } + + init_log()?; + + let outside_prefix: &'static str = { + let path = chromiumos_root + .to_str() + .ok_or_else(|| anyhow!("Path is not valid UTF-8"))?; + + let mut tmp = format!("file://{}", path); + if Some(&b'/') != tmp.as_bytes().last() { + tmp.push('/'); + } + + // No need to ever free this memory, so let's get a static reference. + Box::leak(tmp.into_boxed_str()) + }; + + trace!("Found chromiumos root {}", outside_prefix); + + let inside_prefix: &'static str = "file:///mnt/host/source/"; + + let cmd = "cros_sdk"; + let all_args = ["--", "rust-analyzer"] + .into_iter() + .chain(args.iter().map(|x| x.as_str())); + let mut child = KillOnDrop(run_command(cmd, all_args)?); + + let mut child_stdin = BufWriter::new(child.0.stdin.take().unwrap()); + let mut child_stdout = BufReader::new(child.0.stdout.take().unwrap()); + + let join_handle = { + thread::spawn(move || { + let mut stdin = io::stdin().lock(); + stream_with_replacement(&mut stdin, &mut child_stdin, outside_prefix, inside_prefix) + .context("Streaming from stdin into rust-analyzer") + }) + }; + + let mut stdout = BufWriter::new(io::stdout().lock()); + stream_with_replacement( + &mut child_stdout, + &mut stdout, + inside_prefix, + outside_prefix, + ) + .context("Streaming from rust-analyzer into stdout")?; + + join_handle.join().unwrap()?; + + let code = child.0.wait().context("Running rust-analyzer")?.code(); + std::process::exit(code.unwrap_or(127)); +} + +fn init_log() -> Result<()> { + if !cfg!(feature = "no_debug_log") { + let filename = env::var("RUST_ANALYZER_CHROMIUMOS_WRAPPER_LOG") + .context("Obtaining RUST_ANALYZER_CHROMIUMOS_WRAPPER_LOG environment variable")?; + let file = File::create(&filename).with_context(|| { + format!( + "Opening log file `{}` (value of RUST_ANALYZER_WRAPPER_LOG)", + filename + ) + })?; + WriteLogger::init(LevelFilter::Trace, Config::default(), file) + .with_context(|| format!("Creating WriteLogger with log file `{}`", filename))?; + } + Ok(()) +} + +#[derive(Debug, Default)] +struct Header { + length: Option<usize>, + other_fields: Vec<u8>, +} + +/// Read the `Content-Length` (if present) into `header.length`, and the text of every other header +/// field into `header.other_fields`. +fn read_header<R: BufRead>(r: &mut R, header: &mut Header) -> Result<()> { + header.length = None; + header.other_fields.clear(); + const CONTENT_LENGTH: &[u8] = b"Content-Length:"; + let slen = CONTENT_LENGTH.len(); + loop { + let index = header.other_fields.len(); + + // HTTP header spec says line endings are supposed to be '\r\n' but recommends + // implementations accept just '\n', so let's not worry whether a '\r' is present. + r.read_until(b'\n', &mut header.other_fields) + .context("Reading a header")?; + + let new_len = header.other_fields.len(); + + if new_len <= index + 2 { + // Either we've just received EOF, or just a newline, indicating end of the header. + return Ok(()); + } + if header + .other_fields + .get(index..index + slen) + .map_or(false, |v| v == CONTENT_LENGTH) + { + let s = from_utf8(&header.other_fields[index + slen..]) + .context("Parsing Content-Length")?; + header.length = Some(s.trim().parse().context("Parsing Content-Length")?); + header.other_fields.truncate(index); + } + } +} + +/// Extend `dest` with `contents`, replacing any occurrence of `pattern` in a json string in +/// `contents` with `replacement`. +fn replace(contents: &[u8], pattern: &str, replacement: &str, dest: &mut Vec<u8>) -> Result<()> { + fn map_value(val: Value, pattern: &str, replacement: &str) -> Value { + match val { + Value::String(s) => + // `s.replace` is very likely doing more work than necessary. Probably we only need + // to look for the pattern at the beginning of the string. + { + Value::String(s.replace(pattern, replacement)) + } + Value::Array(mut v) => { + for val_ref in v.iter_mut() { + let value = std::mem::replace(val_ref, Value::Null); + *val_ref = map_value(value, pattern, replacement); + } + Value::Array(v) + } + Value::Object(mut map) => { + // Surely keys can't be paths. + for val_ref in map.values_mut() { + let value = std::mem::replace(val_ref, Value::Null); + *val_ref = map_value(value, pattern, replacement); + } + Value::Object(map) + } + x => x, + } + } + + let init_val: Value = from_slice(contents).with_context(|| match from_utf8(contents) { + Err(_) => format!( + "JSON parsing content of length {} that's not valid UTF-8", + contents.len() + ), + Ok(s) => format!("JSON parsing content of length {}:\n{}", contents.len(), s), + })?; + let mapped_val = map_value(init_val, pattern, replacement); + to_writer(dest, &mapped_val)?; + Ok(()) +} + +/// Read LSP messages from `r`, replacing each occurrence of `pattern` in a json string in the +/// payload with `replacement`, adjusting the `Content-Length` in the header to match, and writing +/// the result to `w`. +fn stream_with_replacement<R: BufRead, W: Write>( + r: &mut R, + w: &mut W, + pattern: &str, + replacement: &str, +) -> Result<()> { + let mut head = Header::default(); + let mut buf = Vec::with_capacity(1024); + let mut buf2 = Vec::with_capacity(1024); + loop { + read_header(r, &mut head)?; + if head.length.is_none() && head.other_fields.len() == 0 { + // No content in the header means we're apparently done. + return Ok(()); + } + let len = head + .length + .ok_or_else(|| anyhow!("No Content-Length in header"))?; + + trace!("Received header with length {}", head.length.unwrap()); + trace!( + "Received header with contents\n{}", + from_utf8(&head.other_fields)? + ); + + buf.resize(len, 0); + r.read_exact(&mut buf) + .with_context(|| format!("Reading payload expecting size {}", len))?; + + trace!("Received payload\n{}", from_utf8(&buf)?); + + buf2.clear(); + replace(&buf, pattern, replacement, &mut buf2)?; + + trace!("After replacements payload\n{}", from_utf8(&buf2)?); + + write!(w, "Content-Length: {}\r\n", buf2.len())?; + w.write_all(&head.other_fields)?; + w.write_all(&buf2)?; + w.flush()?; + } +} + +fn run_command<'a, I>(cmd: &'a str, args: I) -> Result<process::Child> +where + I: IntoIterator<Item = &'a str>, +{ + Ok(process::Command::new(cmd) + .args(args) + .stdin(process::Stdio::piped()) + .stdout(process::Stdio::piped()) + .spawn()?) +} + +fn find_chromiumos_root(start: &Path) -> Option<PathBuf> { + let mut buf = start.to_path_buf(); + loop { + buf.push(".chroot_lock"); + if buf.exists() { + buf.pop(); + return Some(buf); + } + buf.pop(); + if !buf.pop() { + return None; + } + } +} + +struct KillOnDrop(Child); + +impl Drop for KillOnDrop { + fn drop(&mut self) { + let _ = self.0.kill(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn test_stream_with_replacement( + read: &str, + pattern: &str, + replacement: &str, + json_expected: &str, + ) -> Result<()> { + let mut w = Vec::<u8>::with_capacity(read.len()); + stream_with_replacement(&mut read.as_bytes(), &mut w, pattern, replacement)?; + + // serde_json may not format the json output the same as we do, so we can't just compare + // as strings or slices. + + let (w1, w2) = { + let mut split = w.rsplitn(2, |&c| c == b'\n'); + let w2 = split.next().unwrap(); + (split.next().unwrap(), w2) + }; + + assert_eq!( + from_utf8(w1)?, + format!("Content-Length: {}\r\n\r", w2.len()) + ); + + let v1: Value = from_slice(w2)?; + let v2: Value = serde_json::from_str(json_expected)?; + assert_eq!(v1, v2); + + Ok(()) + } + + #[test] + fn test_stream_with_replacement_1() -> Result<()> { + test_stream_with_replacement( + // read + "Content-Length: 93\r\n\r\n{\"somekey\": {\"somepath\": \"XYZXYZabc\",\ + \"anotherpath\": \"somestring\"}, \"anotherkey\": \"XYZXYZdef\"}", + // pattern + "XYZXYZ", + // replacement + "REPLACE", + // json_expected + "{\"somekey\": {\"somepath\": \"REPLACEabc\", \"anotherpath\": \"somestring\"},\ + \"anotherkey\": \"REPLACEdef\"}", + ) + } + + #[test] + fn test_stream_with_replacement_2() -> Result<()> { + test_stream_with_replacement( + // read + "Content-Length: 83\r\n\r\n{\"key0\": \"sometextABCDEF\",\ + \"key1\": {\"key2\": 5, \"key3\": \"moreABCDEFtext\"}, \"key4\": 1}", + // pattern + "ABCDEF", + // replacement + "replacement", + // json_expected + "{\"key0\": \"sometextreplacement\", \"key1\": {\"key2\": 5,\ + \"key3\": \"morereplacementtext\"}, \"key4\": 1}", + ) + } +} |