diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala index 8317e87840..eda6573f4f 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/CliOptions.scala @@ -9,6 +9,7 @@ import org.scalafmt.sysops.OsSpecific import java.io.InputStream import java.io.PrintStream +import java.io.PrintWriter import java.nio.file.Files import java.nio.file.NoSuchFileException import java.nio.file.Path @@ -35,23 +36,26 @@ object CliOptions { * directly from main. */ def auto(parsed: CliOptions): CliOptions = { - val usesOut = parsed.stdIn || parsed.writeMode.usesOut - val auxOut = - if (parsed.noStdErr || !usesOut) parsed.common.out else parsed.common.err - - parsed.copy(common = - parsed.common.copy( - out = guardPrintStream(parsed.quiet && !usesOut)(parsed.common.out), - info = guardPrintStream(parsed.quiet || usesOut)(auxOut), - debug = guardPrintStream(parsed.quiet)( - if (parsed.debug) auxOut else parsed.common.debug, - ), - err = guardPrintStream(parsed.quiet)(parsed.common.err), - ), + val info: Output.StreamOrWriter = + if (parsed.quiet) Output.NoopStream + else { + val usesOut = parsed.stdIn || parsed.writeMode.usesOut + new Output.StreamOrWriter.Stream( + if (parsed.noStdErr || !usesOut) parsed.common.out + else parsed.common.err, + ) + } + val common = parsed.common.copy( + out = guardPrintStream(parsed.quiet && !parsed.stdIn)(parsed.common.out), + info = info, + debug = (if (parsed.debug) info else Output.NoopStream).printWriter, + err = guardPrintStream(parsed.quiet)(parsed.common.err), ) + + parsed.copy(common = common) } - private def guardPrintStream(p: => Boolean)( + private def guardPrintStream(p: Boolean)( candidate: PrintStream, ): PrintStream = if (p) Output.NoopStream.printStream else candidate @@ -62,8 +66,8 @@ case class CommonOptions( out: PrintStream = System.out, in: InputStream = System.in, err: PrintStream = System.err, - debug: PrintStream = Output.NoopStream.printStream, - info: PrintStream = Output.NoopStream.printStream, + debug: PrintWriter = Output.NoopStream.printWriter, + info: Output.StreamOrWriter = Output.NoopStream, ) { private[cli] lazy val workingDirectory: AbsoluteFile = cwd .getOrElse(AbsoluteFile.userDir) diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala index 68f403fbda..5670df036d 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/Output.scala @@ -1,10 +1,33 @@ package org.scalafmt.cli import java.io._ +import java.nio.charset._ object Output { - object NoopStream extends OutputStream { + class WriterStream( + private[scalafmt] val writer: Writer, + charset: Charset = StandardCharsets.UTF_8, + ) extends OutputStream with StreamOrWriter { + + override def write(b: Int): Unit = writer.write(b & 0xf) + override def write(b: Array[Byte]): Unit = writer + .write(new String(b, charset)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = writer + .write(new String(b, off, len, charset)) + + override def flush(): Unit = writer.flush() + override def close(): Unit = writer.close() + + def outputStream: OutputStream = this + override def printStream: PrintStream = new PrintStream(this) + override def printWriter: PrintWriter = writer match { + case x: PrintWriter => x + case _ => new PrintWriter(writer) + } + } + + object NoopStream extends OutputStream with StreamOrWriter { self => override def write(b: Int): Unit = () @@ -18,4 +41,18 @@ object Output { val streamWriter = new OutputStreamWriter(self) } + trait StreamOrWriter { + def outputStream: OutputStream + def printStream: PrintStream + def printWriter: PrintWriter + } + + object StreamOrWriter { + class Stream(val obj: PrintStream) extends StreamOrWriter { + override def outputStream: OutputStream = obj + override def printStream: PrintStream = obj + override def printWriter: PrintWriter = new PrintWriter(obj) + } + } + } diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala index dfb0898d6c..a0abc94b62 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtCliReporter.scala @@ -41,11 +41,10 @@ class ScalafmtCliReporter(options: CliOptions) extends ScalafmtReporter { override def parsedConfig(config: Path, scalafmtVersion: String): Unit = options.common.debug.println(s"parsed config (v$scalafmtVersion): $config") - override def downloadWriter(): PrintWriter = - new PrintWriter(options.common.info) + override def downloadWriter(): PrintWriter = options.common.info.printWriter override def downloadOutputStreamWriter(): OutputStreamWriter = - new OutputStreamWriter(options.common.info) + new OutputStreamWriter(options.common.info.outputStream) } private class FailedToFormat(filename: String, cause: Throwable) diff --git a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala index b5ec5845ab..f75677d59d 100644 --- a/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala +++ b/scalafmt-cli/src/main/scala/org/scalafmt/cli/ScalafmtRunner.scala @@ -4,7 +4,6 @@ import org.scalafmt.Error import org.scalafmt.sysops.AbsoluteFile import org.scalafmt.sysops.BatchPathFinder -import java.io.OutputStreamWriter import java.nio.file.Path trait ScalafmtRunner { @@ -19,7 +18,7 @@ trait ScalafmtRunner { msg: String, ): TermDisplay = { val termDisplay = new TermDisplay( - new OutputStreamWriter(options.common.info), + options.common.info.printWriter, fallbackMode = options.nonInteractive || TermDisplay.defaultFallbackMode, ) if ( diff --git a/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala b/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala index b7caeda755..0d5a1f41bc 100644 --- a/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala +++ b/scalafmt-tests/src/test/scala/org/scalafmt/cli/CliOptionsTest.scala @@ -108,13 +108,17 @@ class CliOptionsTest extends FunSuite { test("write info to out if not writing to stdout") { val options = Cli.getConfig(Array.empty[String], baseCliOptionsWithOut).get - assertEquals(options.common.info, System.out) + assertEquals(options.common.info.printStream, System.out) } Seq("--stdin", "--stdout").foreach { arg => test(s"don't write info when using $arg") { val options = Cli.getConfig(Array(arg), baseCliOptionsWithOut).get - assertEquals(options.common.info, Output.NoopStream.printStream) + options.common.info match { + case x: Output.StreamOrWriter.Stream + if x.obj eq Output.NoopStream.printStream => + case x => fail(s"info should be writing to NoopStream: $x") + } } }