diff options
Diffstat (limited to 'nearby/src/file_header/mod.rs')
-rw-r--r-- | nearby/src/file_header/mod.rs | 559 |
1 files changed, 559 insertions, 0 deletions
diff --git a/nearby/src/file_header/mod.rs b/nearby/src/file_header/mod.rs new file mode 100644 index 0000000..275a7d0 --- /dev/null +++ b/nearby/src/file_header/mod.rs @@ -0,0 +1,559 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Tools for checking for, or adding, headers (e.g. licenses, etc) in files. + +use std::{ + fs, + io::{self, BufRead as _, Write as _}, + iter::FromIterator, + path, thread, +}; + +pub mod license; + +/// A file header to check for or add to files. +#[derive(Clone)] +pub struct Header<C: HeaderChecker> { + checker: C, + header: String, +} + +impl<C: HeaderChecker> Header<C> { + /// Construct a new `Header` with the `checker` used to determine if the header is already + /// present, and the plain `header` text to add (without any applicable comment syntax, etc). + pub fn new(checker: C, header: String) -> Self { + Self { checker, header } + } + + /// Return true if the file has the desired header, false otherwise. + pub fn header_present(&self, input: &mut impl io::Read) -> io::Result<bool> { + self.checker.check(input) + } + + /// Add the header, with appropriate formatting for the type of file indicated by `p`'s + /// extension, if the header is not already present. + /// Returns true if the header was added. 
+ pub fn add_header_if_missing(&self, p: &path::Path) -> Result<bool, AddHeaderError> { + let err_mapper = |e| AddHeaderError::IoError(p.to_path_buf(), e); + let contents = fs::read_to_string(p).map_err(err_mapper)?; + + if self.header_present(&mut contents.as_bytes()).map_err(err_mapper)? { + return Ok(false); + } + + let mut effective_header = header_delimiters(p) + .ok_or_else(|| AddHeaderError::UnknownExtension(p.to_path_buf())) + .map(|d| wrap_header(&self.header, d))?; + + let mut after_header = contents.as_str(); + // check for a magic first line + if let Some((first_line, rest)) = contents.split_once('\n') { + if MAGIC_FIRST_LINES.iter().any(|l| first_line.contains(l)) { + let mut first_line = first_line.to_string(); + first_line.push('\n'); + effective_header.insert_str(0, &first_line); + after_header = rest; + } + } + + // write the license + let mut f = + fs::OpenOptions::new().write(true).truncate(true).open(p).map_err(err_mapper)?; + f.write_all(effective_header.as_bytes()).map_err(err_mapper)?; + // newline to separate the header from previous contents + f.write_all("\n".as_bytes()).map_err(err_mapper)?; + f.write_all(after_header.as_bytes()).map_err(err_mapper)?; + + Ok(true) + } +} + +/// Errors that can occur when adding a header +#[derive(Debug, thiserror::Error)] +pub enum AddHeaderError { + #[error("I/O error at {0:?}: {1}")] + IoError(path::PathBuf, io::Error), + #[error("Unknown file extension: {0:?}")] + UnknownExtension(path::PathBuf), +} + +/// Checks for headers in files, like licenses or author attribution. +pub trait HeaderChecker: Send + Clone { + /// Return true if the file has the desired header, false otherwise. + fn check(&self, file: &mut impl io::Read) -> io::Result<bool>; +} + +/// Checks for a in the first several lines of each file. 
+#[derive(Clone)] +pub struct SingleLineChecker { + /// Pattern to do a substring match on in each of the first `max_lines` lines of the file + pattern: String, + /// Number of lines to search through + max_lines: usize, +} + +impl SingleLineChecker { + /// Construct a `SingleLineChecker` that looks for `pattern` in the first `max_lines` of a file. + pub(crate) fn new(pattern: String, max_lines: usize) -> Self { + Self { pattern, max_lines } + } +} + +impl HeaderChecker for SingleLineChecker { + fn check(&self, input: &mut impl io::Read) -> io::Result<bool> { + let mut reader = io::BufReader::new(input); + let mut lines_read = 0; + // reuse buffer to minimize allocation + let mut line = String::new(); + // only read the first bit of the file + while lines_read < self.max_lines { + line.clear(); + let bytes = reader.read_line(&mut line)?; + if bytes == 0 { + // EOF + return Ok(false); + } + lines_read += 1; + + if line.contains(&self.pattern) { + return Ok(true); + } + } + + Ok(false) + } +} + +#[derive(Copy, Clone)] +enum CheckStatus { + MisMatchedHeader, + BinaryFile, +} + +#[derive(Clone)] +struct FileResult { + path: path::PathBuf, + status: CheckStatus, +} + +#[derive(Clone, Default)] +pub struct FileResults { + pub mismatched_files: Vec<path::PathBuf>, + pub binary_files: Vec<path::PathBuf>, +} + +impl FileResults { + pub fn has_failure(&self) -> bool { + !self.mismatched_files.is_empty() || !self.binary_files.is_empty() + } +} + +impl FromIterator<FileResult> for FileResults { + fn from_iter<I>(iter: I) -> FileResults + where + I: IntoIterator<Item = FileResult>, + { + let mut results = FileResults::default(); + for result in iter { + match result.status { + CheckStatus::MisMatchedHeader => results.mismatched_files.push(result.path), + CheckStatus::BinaryFile => results.binary_files.push(result.path), + } + } + results + } +} + +/// Recursively check for `header` in every file in `root` that matches `path_predicate`. 
+/// +/// Returns a [`FileResults`] object containing the paths without headers detected. +pub fn check_headers_recursively( + root: &path::Path, + path_predicate: impl Fn(&path::Path) -> bool, + header: Header<impl HeaderChecker + 'static>, + num_threads: usize, +) -> Result<FileResults, CheckHeadersRecursivelyError> { + let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>(); + let (result_tx, result_rx) = crossbeam::channel::unbounded(); + + // spawn a few threads to handle files in parallel + let handles = (0..num_threads) + .map(|_| { + let path_rx = path_rx.clone(); + let result_tx = result_tx.clone(); + let header = header.clone(); + thread::spawn(move || { + for p in path_rx { + match fs::File::open(&p).and_then(|mut f| header.header_present(&mut f)) { + Ok(header_present) => { + if header_present { + // no op + } else { + let res = + FileResult { path: p, status: CheckStatus::MisMatchedHeader }; + result_tx.send(Ok(res)).unwrap(); + } + } + Err(e) if e.kind() == io::ErrorKind::InvalidData => { + // Binary file - add to ignore in license.rs + let res = FileResult { path: p, status: CheckStatus::BinaryFile }; + result_tx.send(Ok(res)).unwrap(); + } + Err(e) => result_tx + .send(Err(CheckHeadersRecursivelyError::IoError(p, e))) + .unwrap(), + } + } + + // no more files + }) + }) + .collect::<Vec<thread::JoinHandle<()>>>(); + // make sure result channel closes when threads complete + drop(result_tx); + + find_files(root, path_predicate, path_tx)?; + + let res: FileResults = result_rx.into_iter().collect::<Result<_, _>>()?; + + for h in handles { + h.join().unwrap(); + } + + Ok(res) +} + +/// Errors that can occur when checking for headers recursively +#[derive(Debug, thiserror::Error)] +pub enum CheckHeadersRecursivelyError { + #[error("I/O error at {0:?}: {1}")] + IoError(path::PathBuf, io::Error), + #[error("Walkdir error: {0}")] + WalkdirError(#[from] walkdir::Error), +} + +/// Add the provided `header` to any file in `root` that matches 
`path_predicate` and that doesn't +/// already have a header as determined by `checker`. +/// Returns a list of paths that had headers added. +pub fn add_headers_recursively( + root: &path::Path, + path_predicate: impl Fn(&path::Path) -> bool, + header: Header<impl HeaderChecker>, +) -> Result<Vec<path::PathBuf>, AddHeadersRecursivelyError> { + // likely no need for threading since adding headers is only done occasionally + let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>(); + find_files(root, path_predicate, path_tx)?; + + path_rx + .into_iter() + // keep the errors, or the ones with added headers + .filter_map(|p| { + match header.add_header_if_missing(&p).map_err(|e| match e { + AddHeaderError::IoError(p, e) => AddHeadersRecursivelyError::IoError(p, e), + AddHeaderError::UnknownExtension(e) => { + AddHeadersRecursivelyError::UnknownExtension(e) + } + }) { + Ok(added) => { + if added { + Some(Ok(p)) + } else { + None + } + } + Err(e) => Some(Err(e)), + } + }) + .collect::<Result<Vec<_>, _>>() +} + +/// Errors that can occur when adding a header recursively +#[derive(Debug, thiserror::Error)] +pub enum AddHeadersRecursivelyError { + #[error("I/O error at {0:?}: {1}")] + IoError(path::PathBuf, io::Error), + #[error("Walkdir error: {0}")] + WalkdirError(#[from] walkdir::Error), + #[error("Unknown file extension: {0:?}")] + UnknownExtension(path::PathBuf), +} + +/// Find all files starting from `root` that do not match the globs in `ignore`, publishing the +/// resulting paths into `dest`. 
+fn find_files( + root: &path::Path, + path_predicate: impl Fn(&path::Path) -> bool, + dest: crossbeam::channel::Sender<path::PathBuf>, +) -> Result<(), walkdir::Error> { + for r in walkdir::WalkDir::new(root).into_iter() { + let entry = r?; + if entry.path().is_dir() || !path_predicate(entry.path()) { + continue; + } + dest.send(entry.into_path()).unwrap() + } + + Ok(()) +} + +/// Prepare a header for inclusion in a particular file syntax by wrapping it with +/// comment characters as per the provided `delim`. +fn wrap_header(orig_header: &str, delim: HeaderDelimiters) -> String { + let mut out = String::new(); + + if !delim.first_line.is_empty() { + out.push_str(delim.first_line); + out.push('\n'); + } + + // assumes header uses \n + for line in orig_header.split('\n') { + out.push_str(delim.content_line_prefix); + out.push_str(line); + // Remove any trailing whitespaces (excluding newlines) from `content_line_prefix + line`. + // For example, if `content_line_prefix` is `// ` and `line` is empty, the resulting string + // should be truncated to `//`. + out.truncate(out.trim_end_matches([' ', '\t']).len()); + out.push('\n'); + } + + if !delim.last_line.is_empty() { + out.push_str(delim.last_line); + out.push('\n'); + } + + out +} + +/// Returns the header prefix line, content line prefix, and suffix line for the extension of the +/// provided path, or `None` if the extension is not recognized. 
+fn header_delimiters(p: &path::Path) -> Option<HeaderDelimiters> { + match p + .extension() + // if the extension isn't UTF-8, oh well + .and_then(|os_str| os_str.to_str()) + .unwrap_or("") + { + "c" | "h" | "gv" | "java" | "scala" | "kt" | "kts" => Some(("/*", " * ", " */")), + "js" | "mjs" | "cjs" | "jsx" | "tsx" | "css" | "scss" | "sass" | "ts" => { + Some(("/**", " * ", " */")) + } + "cc" | "cpp" | "cs" | "go" | "hcl" | "hh" | "hpp" | "m" | "mm" | "proto" | "rs" + | "swift" | "dart" | "groovy" | "v" | "sv" => Some(("", "// ", "")), + "py" | "sh" | "yaml" | "yml" | "dockerfile" | "rb" | "gemfile" | "tcl" | "tf" | "bzl" + | "pl" | "pp" | "build" => Some(("", "# ", "")), + "el" | "lisp" => Some(("", ";; ", "")), + "erl" => Some(("", "% ", "")), + "hs" | "lua" | "sql" | "sdl" => Some(("", "-- ", "")), + "html" | "xml" | "vue" | "wxi" | "wxl" | "wxs" => Some(("<!--", " ", "-->")), + "php" => Some(("", "// ", "")), + "ml" | "mli" | "mll" | "mly" => Some(("(**", " ", "*)")), + // also handle whole filenames if extensions didn't match + _ => match p.file_name().and_then(|os_str| os_str.to_str()).unwrap_or("") { + "Dockerfile" => Some(("", "# ", "")), + _ => None, + }, + } + .map(|(first_line, content_line_prefix, last_line)| HeaderDelimiters { + first_line, + content_line_prefix, + last_line, + }) +} + +/// Delimiters to use around and inside a header for a particular file syntax. 
#[derive(Clone, Copy)]
struct HeaderDelimiters {
    /// Line emitted before the header; skipped when empty
    first_line: &'static str,
    /// Text placed in front of every line of the header itself
    content_line_prefix: &'static str,
    /// Line emitted after the header; skipped when empty
    last_line: &'static str,
}

/// Markers identifying a first line that must stay at the top of a file; a header is inserted
/// after such a line rather than before it.
const MAGIC_FIRST_LINES: [&str; 8] = [
    "#!",                       // shell script
    "<?xml",                    // XML declaration
    "<!doctype",                // HTML doctype
    "# encoding:",              // Ruby encoding
    "# frozen_string_literal:", // Ruby interpreter instruction
    "<?php",                    // PHP opening tag
    "# escape", // Dockerfile directive https://docs.docker.com/engine/reference/builder/#parser-directives
    "# syntax", // Dockerfile directive https://docs.docker.com/engine/reference/builder/#parser-directives
];

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_line_checker_finds_header_when_present() {
        let text = r#"foo
            some license
            bar"#;

        let found = test_header().checker.check(&mut text.as_bytes()).unwrap();
        assert!(found);
    }

    #[test]
    fn single_line_checker_doesnt_find_header_when_missing() {
        let text = r#"foo
            wrong license
            bar"#;

        let found = test_header().checker.check(&mut text.as_bytes()).unwrap();
        assert!(!found);
    }

    #[test]
    fn single_line_checker_throws_error_when_missing_and_file_is_non_utf8() {
        let raw = b"foo
            \x00\xff
            bar";

        let err = test_header().checker.check(&mut raw.as_slice()).unwrap_err();
        assert_eq!(io::ErrorKind::InvalidData, err.kind());
    }

    #[test]
    fn single_line_checker_doesnt_panic_when_file_is_non_utf8() {
        let inputs: [&'static [u8]; 3] = [
            b"foo
                \x00\xff
                bar",
            b"foo
                some license
                \x00\xff
                bar",
            b"foo
                \x00\xff
                some license
                bar",
        ];

        for mut input in inputs {
            // Output is not defined for non-utf-8 files, but we should handle them with grace
            let _ = test_header().checker.check(&mut input);
        }
    }

    #[test]
    fn adds_header_with_empty_delimiters() {
        let file = temp_file_with(".rs", r#"not a license"#);

        test_header().add_header_if_missing(file.path()).unwrap();

        assert_eq!(
            "// some license etc etc etc

not a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn adds_header_with_nonempty_delimiters() {
        let file = temp_file_with(".c", r#"not a license"#);

        test_header().add_header_if_missing(file.path()).unwrap();

        assert_eq!(
            "/*
 * some license etc etc etc
 */

not a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn adds_header_trim_trailing_whitespace() {
        let file = temp_file_with(".c", r#"not a license"#);

        test_header_with_blank_lines_and_trailing_whitespace()
            .add_header_if_missing(file.path())
            .unwrap();

        assert_eq!(
            "/*
 * some license
 * line with trailing whitespace.
 *
 * etc
 */

not a license",
            fs::read_to_string(file.path()).unwrap()
        );
    }

    #[test]
    fn doesnt_add_header_when_already_present() {
        let initial_content = r#"
            // some license etc etc etc already present
            not a license"#;
        let file = temp_file_with(".rs", initial_content);

        test_header().add_header_if_missing(file.path()).unwrap();

        assert_eq!(initial_content, fs::read_to_string(file.path()).unwrap());
    }

    #[test]
    fn adds_header_after_magic_first_line() {
        let file = temp_file_with(
            ".xml",
            r#"<?xml version="1.0" encoding="UTF-8"?>
<root />
"#,
        );

        test_header().add_header_if_missing(file.path()).unwrap();

        assert_eq!(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<!--
 some license etc etc etc
-->

<root />
"#,
            fs::read_to_string(file.path()).unwrap()
        );
    }

    /// Creates a temp file whose name ends in `suffix` and fills it with `contents`.
    fn temp_file_with(suffix: &str, contents: &str) -> tempfile::NamedTempFile {
        let file = tempfile::Builder::new().suffix(suffix).tempfile().unwrap();
        fs::write(file.path(), contents).unwrap();
        file
    }

    /// Header whose checker looks for "some license" in the first 100 lines.
    fn test_header() -> Header<SingleLineChecker> {
        Header::new(
            SingleLineChecker::new("some license".to_string(), 100),
            r#"some license etc etc etc"#.to_string(),
        )
    }

    /// Like `test_header`, but the header text has a blank line and trailing whitespace.
    fn test_header_with_blank_lines_and_trailing_whitespace() -> Header<SingleLineChecker> {
        Header::new(
            SingleLineChecker::new("some license".to_string(), 100),
            "some license\nline with trailing whitespace. \n\netc".to_string(),
        )
    }
}