Skip to content

Commit

Permalink
Formatter: Add SourceType to context to enable special formatting for…
Browse files Browse the repository at this point in the history
… stub files (#6331)

**Summary** This adds the information whether we're in a .py python
source file or in a .pyi stub file to enable people working on #5822 and
related issues.

I'm not completely happy with `Default` for something that depends on
the input.

**Test Plan** None, this is currently unused, i'm leaving this to first
implementation of stub file specific formatting.

---------

Co-authored-by: Micha Reiser <micha@reiser.io>
  • Loading branch information
konstin and MichaReiser authored Aug 4, 2023
1 parent fe97a2a commit 1031bb6
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 97 deletions.
8 changes: 3 additions & 5 deletions crates/ruff/src/rules/pycodestyle/rules/type_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,13 @@ pub(crate) fn type_comparison(checker: &mut Checker, compare: &ast::ExprCompare)
}

// Left-hand side must be, e.g., `type(obj)`.
let Expr::Call(ast::ExprCall {
func, ..
}) = left else {
let Expr::Call(ast::ExprCall { func, .. }) = left else {
continue;
};

let Expr::Name(ast::ExprName { id, .. }) = func.as_ref() else {
continue;
};
continue;
};

if !(id == "type" && checker.semantic().is_builtin("type")) {
continue;
Expand Down
5 changes: 3 additions & 2 deletions crates/ruff_benchmark/benches/formatter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruff_benchmark::{TestCase, TestCaseSpeed, TestFile, TestFileDownloadError};
use ruff_python_formatter::{format_module, PyFormatOptions};
use std::path::Path;
use std::time::Duration;

#[cfg(target_os = "windows")]
Expand Down Expand Up @@ -51,8 +52,8 @@ fn benchmark_formatter(criterion: &mut Criterion) {
&case,
|b, case| {
b.iter(|| {
format_module(case.code(), PyFormatOptions::default())
.expect("Formatting to succeed")
let options = PyFormatOptions::from_extension(Path::new(case.name()));
format_module(case.code(), options).expect("Formatting to succeed")
});
},
);
Expand Down
13 changes: 4 additions & 9 deletions crates/ruff_cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,20 @@ fn format(files: &[PathBuf]) -> Result<ExitStatus> {
internal use only."
);

let format_code = |code: &str| {
// dummy, to check that the function was actually called
let contents = code.replace("# DEL", "");
// real formatting that is currently a passthrough
format_module(&contents, PyFormatOptions::default())
};

match &files {
// Check if we should read from stdin
[path] if path == Path::new("-") => {
let unformatted = read_from_stdin()?;
let formatted = format_code(&unformatted)?;
let options = PyFormatOptions::from_extension(Path::new("stdin.py"));
let formatted = format_module(&unformatted, options)?;
stdout().lock().write_all(formatted.as_code().as_bytes())?;
}
_ => {
for file in files {
let unformatted = std::fs::read_to_string(file)
.with_context(|| format!("Could not read {}: ", file.display()))?;
let formatted = format_code(&unformatted)?;
let options = PyFormatOptions::from_extension(file);
let formatted = format_module(&unformatted, options)?;
std::fs::write(file, formatted.as_code().as_bytes())
.with_context(|| format!("Could not write to {}, exiting", file.display()))?;
}
Expand Down
58 changes: 28 additions & 30 deletions crates/ruff_dev/src/format_dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,10 @@ fn format_dev_project(

// TODO(konstin): The assumptions between this script (one repo) and ruff (pass in a bunch of
// files) mismatch.
let options = BlackOptions::from_file(&files[0])?.to_py_format_options();
let black_options = BlackOptions::from_file(&files[0])?;
debug!(
parent: None,
"Options for {}: {options:?}",
"Options for {}: {black_options:?}",
files[0].display()
);

Expand All @@ -398,7 +398,7 @@ fn format_dev_project(
paths
.into_par_iter()
.map(|dir_entry| {
let result = format_dir_entry(dir_entry, stability_check, write, &options);
let result = format_dir_entry(dir_entry, stability_check, write, &black_options);
pb_span.pb_inc(1);
result
})
Expand Down Expand Up @@ -447,7 +447,7 @@ fn format_dir_entry(
dir_entry: Result<DirEntry, ignore::Error>,
stability_check: bool,
write: bool,
options: &PyFormatOptions,
options: &BlackOptions,
) -> anyhow::Result<(Result<Statistics, CheckFileError>, PathBuf), Error> {
let dir_entry = match dir_entry.context("Iterating the files in the repository failed") {
Ok(dir_entry) => dir_entry,
Expand All @@ -460,27 +460,27 @@ fn format_dir_entry(
}

let file = dir_entry.path().to_path_buf();
let options = options.to_py_format_options(&file);
// Handle panics (mostly in `debug_assert!`)
let result =
match catch_unwind(|| format_dev_file(&file, stability_check, write, options.clone())) {
Ok(result) => result,
Err(panic) => {
if let Some(message) = panic.downcast_ref::<String>() {
Err(CheckFileError::Panic {
message: message.clone(),
})
} else if let Some(&message) = panic.downcast_ref::<&str>() {
Err(CheckFileError::Panic {
message: message.to_string(),
})
} else {
Err(CheckFileError::Panic {
// This should not happen, but it can
message: "(Panic didn't set a string message)".to_string(),
})
}
let result = match catch_unwind(|| format_dev_file(&file, stability_check, write, options)) {
Ok(result) => result,
Err(panic) => {
if let Some(message) = panic.downcast_ref::<String>() {
Err(CheckFileError::Panic {
message: message.clone(),
})
} else if let Some(&message) = panic.downcast_ref::<&str>() {
Err(CheckFileError::Panic {
message: message.to_string(),
})
} else {
Err(CheckFileError::Panic {
// This should not happen, but it can
message: "(Panic didn't set a string message)".to_string(),
})
}
};
}
};
Ok((result, file))
}

Expand Down Expand Up @@ -833,18 +833,16 @@ impl BlackOptions {
Self::from_toml(&fs::read_to_string(&path)?, repo)
}

fn to_py_format_options(&self) -> PyFormatOptions {
let mut options = PyFormatOptions::default();
options
fn to_py_format_options(&self, file: &Path) -> PyFormatOptions {
PyFormatOptions::from_extension(file)
.with_line_width(
LineWidth::try_from(self.line_length).expect("Invalid line length limit"),
)
.with_magic_trailing_comma(if self.skip_magic_trailing_comma {
MagicTrailingComma::Ignore
} else {
MagicTrailingComma::Respect
});
options
})
}
}

Expand All @@ -868,7 +866,7 @@ mod tests {
"};
let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml"))
.unwrap()
.to_py_format_options();
.to_py_format_options(Path::new("code_inline.py"));
assert_eq!(options.line_width(), LineWidth::try_from(119).unwrap());
assert!(matches!(
options.magic_trailing_comma(),
Expand All @@ -887,7 +885,7 @@ mod tests {
"#};
let options = BlackOptions::from_toml(toml, Path::new("pyproject.toml"))
.unwrap()
.to_py_format_options();
.to_py_format_options(Path::new("code_inline.py"));
assert_eq!(options.line_width(), LineWidth::try_from(130).unwrap());
assert!(matches!(
options.magic_trailing_comma(),
Expand Down
34 changes: 34 additions & 0 deletions crates/ruff_python_ast/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ruff_text_size::{TextRange, TextSize};
use std::path::Path;

pub mod all;
pub mod call_path;
Expand Down Expand Up @@ -49,3 +50,36 @@ where
T::range(self)
}
}

#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PySourceType {
#[default]
Python,
Stub,
Jupyter,
}

impl PySourceType {
pub const fn is_python(&self) -> bool {
matches!(self, PySourceType::Python)
}

pub const fn is_stub(&self) -> bool {
matches!(self, PySourceType::Stub)
}

pub const fn is_jupyter(&self) -> bool {
matches!(self, PySourceType::Jupyter)
}
}

impl From<&Path> for PySourceType {
fn from(path: &Path) -> Self {
match path.extension() {
Some(ext) if ext == "pyi" => PySourceType::Stub,
Some(ext) if ext == "ipynb" => PySourceType::Jupyter,
_ => PySourceType::Python,
}
}
}
6 changes: 3 additions & 3 deletions crates/ruff_python_formatter/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ smallvec = { workspace = true }
thiserror = { workspace = true }

[dev-dependencies]
ruff_formatter = { path = "../ruff_formatter", features = ["serde"]}
ruff_formatter = { path = "../ruff_formatter", features = ["serde"] }

insta = { workspace = true, features = ["glob"] }
serde = { workspace = true }
Expand All @@ -43,8 +43,8 @@ similar = { workspace = true }
name = "ruff_python_formatter_fixtures"
path = "tests/fixtures.rs"
test = true
required-features = [ "serde" ]
required-features = ["serde"]

[features]
serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde"]
serde = ["dep:serde", "ruff_formatter/serde", "ruff_source_file/serde", "ruff_python_ast/serde"]
default = ["serde"]
18 changes: 7 additions & 11 deletions crates/ruff_python_formatter/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#![allow(clippy::print_stdout)]

use std::path::PathBuf;
use std::path::{Path, PathBuf};

use anyhow::{bail, Context, Result};
use clap::{command, Parser, ValueEnum};
use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode};

use ruff_formatter::SourceCode;
use ruff_python_index::CommentRangesBuilder;
use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode};

use crate::{format_node, PyFormatOptions};

Expand Down Expand Up @@ -37,7 +37,7 @@ pub struct Cli {
pub print_comments: bool,
}

pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result<String> {
pub fn format_and_debug_print(input: &str, cli: &Cli, source_type: &Path) -> Result<String> {
let mut tokens = Vec::new();
let mut comment_ranges = CommentRangesBuilder::default();

Expand All @@ -57,13 +57,9 @@ pub fn format_and_debug_print(input: &str, cli: &Cli) -> Result<String> {
let python_ast =
parse_tokens(tokens, Mode::Module, "<filename>").context("Syntax error in input")?;

let formatted = format_node(
&python_ast,
&comment_ranges,
input,
PyFormatOptions::default(),
)
.context("Failed to format node")?;
let options = PyFormatOptions::from_extension(source_type);
let formatted = format_node(&python_ast, &comment_ranges, input, options)
.context("Failed to format node")?;
if cli.print_ir {
println!("{}", formatted.document().display(SourceCode::new(input)));
}
Expand Down
14 changes: 5 additions & 9 deletions crates/ruff_python_formatter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ mod tests {
use ruff_python_index::CommentRangesBuilder;
use ruff_python_parser::lexer::lex;
use ruff_python_parser::{parse_tokens, Mode};
use std::path::Path;

/// Very basic test intentionally kept very similar to the CLI
#[test]
Expand Down Expand Up @@ -321,15 +322,10 @@ with [
let comment_ranges = comment_ranges.finish();

// Parse the AST.
let python_ast = parse_tokens(tokens, Mode::Module, "<filename>").unwrap();

let formatted = format_node(
&python_ast,
&comment_ranges,
src,
PyFormatOptions::default(),
)
.unwrap();
let source_path = "code_inline.py";
let python_ast = parse_tokens(tokens, Mode::Module, source_path).unwrap();
let options = PyFormatOptions::from_extension(Path::new(source_path));
let formatted = format_node(&python_ast, &comment_ranges, src, options).unwrap();

// Uncomment the `dbg` to print the IR.
// Use `dbg_write!(f, []) instead of `write!(f, [])` in your formatting code to print some IR
Expand Down
6 changes: 4 additions & 2 deletions crates/ruff_python_formatter/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::{stdout, Read, Write};
use std::path::Path;
use std::{fs, io};

use anyhow::{bail, Context, Result};
Expand All @@ -25,7 +26,8 @@ fn main() -> Result<()> {
);
}
let input = read_from_stdin()?;
let formatted = format_and_debug_print(&input, &cli)?;
// It seems reasonable to give this a dummy name
let formatted = format_and_debug_print(&input, &cli, Path::new("stdin.py"))?;
if cli.check {
if formatted == input {
return Ok(());
Expand All @@ -37,7 +39,7 @@ fn main() -> Result<()> {
for file in &cli.files {
let input = fs::read_to_string(file)
.with_context(|| format!("Could not read {}: ", file.display()))?;
let formatted = format_and_debug_print(&input, &cli)?;
let formatted = format_and_debug_print(&input, &cli, file)?;
match cli.emit {
Some(Emit::Stdout) => stdout().lock().write_all(formatted.as_bytes())?,
None | Some(Emit::Files) => {
Expand Down
Loading

0 comments on commit 1031bb6

Please sign in to comment.