diff --git a/modules/idlgen/core/src/main/scala/higherkindness/mu/rpc/idlgen/proto/ProtoSrcGenerator.scala b/modules/idlgen/core/src/main/scala/higherkindness/mu/rpc/idlgen/proto/ProtoSrcGenerator.scala index efe585c0d..a4e5492d1 100644 --- a/modules/idlgen/core/src/main/scala/higherkindness/mu/rpc/idlgen/proto/ProtoSrcGenerator.scala +++ b/modules/idlgen/core/src/main/scala/higherkindness/mu/rpc/idlgen/proto/ProtoSrcGenerator.scala @@ -20,6 +20,7 @@ import java.io.File import cats.effect.{IO, Sync} import cats.syntax.functor._ +import cats.syntax.option._ import higherkindness.mu.rpc.idlgen.Model.{ CompressionTypeGen, GzipGen, @@ -53,19 +54,21 @@ object ProtoSrcGenerator { options: String*): Option[(String, Seq[String])] = getCode[IO](inputFile).map(Some(_)).unsafeRunSync - def withImports(self: String): String = - (self.split("\n", 2).toList match { - case h :: t => imports(h) :: t - case a => a - }).mkString("\n") + def withImports(lines: List[String]): List[String] = + lines match { + case h :: t => + // first line of file is package declaration + h :: imports(t) ++ t + case a => a + } val copRegExp: Regex = """((Cop\[)(((\w+)((\[)(\w+)(\]))?(\s)?(\:\:)(\s)?)+)(TNil)(\]))""".r val cleanCop: String => String = _.replace("Cop[", "").replace("::", ":+:").replace("TNil]", "CNil") - val withCoproducts: String => String = self => - copRegExp.replaceAllIn(self, m => cleanCop(m.matched)) + val withCoproducts: List[String] => List[String] = lines => + lines.map(line => copRegExp.replaceAllIn(line, m => cleanCop(m.matched))) val skeuomorphCompression: CompressionType = compressionTypeGen match { case GzipGen => CompressionType.Gzip @@ -79,19 +82,32 @@ object ProtoSrcGenerator { val printProtocol: higherkindness.skeuomorph.mu.Protocol[Mu[MuF]] => String = higherkindness.skeuomorph.mu.print.proto.print + val splitLines: String => List[String] = _.split("\n").toList + private def getCode[F[_]: Sync](file: File): F[(String, Seq[String])] = parseProto[F, Mu[ProtobufF]] .parse(ProtoSource(file.getName, file.getParent, Some(idlTargetDir.getCanonicalPath))) - .map( - protocol => - getPath(protocol) -> Seq( - (parseProtocol andThen printProtocol andThen withImports andThen withCoproducts)( - protocol))) + .map(protocol => + getPath(protocol) -> + (parseProtocol andThen printProtocol andThen splitLines andThen withCoproducts andThen withImports)( + protocol)) private def getPath(p: Protocol[Mu[ProtobufF]]): String = s"${p.pkg.replace('.', '/')}/${p.name}$ScalaFileExtension" - def imports(pkg: String): String = - s"$pkg\nimport higherkindness.mu.rpc.protocol._\nimport fs2.Stream\nimport shapeless.{:+:, CNil}" + def imports(fileLines: List[String]): List[String] = { + List( + "import higherkindness.mu.rpc.protocol._".some, + if (fileLines.exists(_.contains("Stream[F,"))) + "import fs2.Stream".some + else + None, + if (fileLines.exists(_.contains(":+:"))) + "import shapeless.{:+:, CNil}".some + else + None + ).flatten + } + } } diff --git a/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_no.proto b/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_no.proto new file mode 100644 index 000000000..56a2d6374 --- /dev/null +++ b/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_no.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package com.proto; + +message MyRequest { + string a = 1; +} + +message MyResponse { + string a = 1; +} + +service MyService { + rpc GetThing (MyRequest) returns (MyResponse) {} +} diff --git a/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_yes.proto b/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_yes.proto new file mode 100644 index 000000000..385ef27e7 --- /dev/null +++ b/modules/idlgen/core/src/test/resources/proto/streaming_no_shapeless_yes.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package com.proto; + +message MyRequest { + oneof a { + int32 b = 1; + string c = 2; + } +} + +message MyResponse { + string a = 1; +} + +service MyService { + rpc GetThing (MyRequest) returns (MyResponse) {} +} diff --git a/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_no.proto b/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_no.proto new file mode 100644 index 000000000..9d6d39b6b --- /dev/null +++ b/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_no.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package com.proto; + +message MyRequest { + string a = 1; +} + +message MyResponse { + string a = 1; +} + +service MyService { + rpc GetThing (MyRequest) returns (stream MyResponse) {} +} diff --git a/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_yes.proto b/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_yes.proto new file mode 100644 index 000000000..be06b80c7 --- /dev/null +++ b/modules/idlgen/core/src/test/resources/proto/streaming_yes_shapeless_yes.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package com.proto; + +message MyRequest { + oneof a { + int32 b = 1; + string c = 2; + } +} + +message MyResponse { + string a = 1; +} + +service MyService { + rpc GetThing (MyRequest) returns (stream MyResponse) {} +} diff --git a/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/SrcGenTests.scala b/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/AvroSrcGenTests.scala similarity index 96% rename from modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/SrcGenTests.scala rename to modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/AvroSrcGenTests.scala index c7cd4cbc0..60311cabd 100644 --- a/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/SrcGenTests.scala +++ b/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/AvroSrcGenTests.scala @@ -23,7 +23,7 @@ import higherkindness.mu.rpc.idlgen.avro._ import org.scalatestplus.scalacheck.Checkers import org.scalacheck.Prop.forAll -class SrcGenTests extends RpcBaseTestSuite with Checkers { +class AvroSrcGenTests extends RpcBaseTestSuite with Checkers { "Avro Scala Generator" should { diff --git a/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/ProtoSrcGenTests.scala b/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/ProtoSrcGenTests.scala index 674741ca0..4cc8e1207 100644 --- a/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/ProtoSrcGenTests.scala +++ b/modules/idlgen/core/src/test/scala/higherkindness/mu/rpc/idlgen/ProtoSrcGenTests.scala @@ -21,28 +21,62 @@ import java.io.File import higherkindness.mu.rpc.common.RpcBaseTestSuite import higherkindness.mu.rpc.idlgen.proto.ProtoSrcGenerator import higherkindness.mu.rpc.idlgen.Model.{NoCompressionGen, UseIdiomaticEndpoints} +import org.scalatest.OptionValues -class ProtoSrcGenTests extends RpcBaseTestSuite { +class ProtoSrcGenTests extends RpcBaseTestSuite with OptionValues { - val module: String = new java.io.File(".").getCanonicalPath - val protoFile: File = new File(module + "/src/test/resources/proto/book.proto") + val module: String = new java.io.File(".").getCanonicalPath + def protoFile(filename: String): File = + new File(s"$module/src/test/resources/proto/$filename.proto") "Proto Scala Generator" should { "generate correct Scala classes" in { - val result: Option[(String, Seq[String])] = + val result: Option[(String, String)] = ProtoSrcGenerator .build(NoCompressionGen, UseIdiomaticEndpoints(false), new java.io.File(".")) - .generateFrom(files = Set(protoFile), serializationType = "", options = "") - .map(t => (t._2, t._3.map(_.clean))) + .generateFrom(files = Set(protoFile("book")), serializationType = "", options = "") + .map(t => (t._2, t._3.mkString("\n").clean)) .headOption - result shouldBe Some(("com/proto/book.scala", Seq(expectation.clean))) + result shouldBe Some(("com/proto/book.scala", bookExpectation.clean)) } + + case class ImportsTestCase( + protoFilename: String, + shouldIncludeFS2Import: Boolean, + shouldIncludeShapelessImport: Boolean + ) + + for (test <- List( + ImportsTestCase("streaming_no_shapeless_no", false, false), + ImportsTestCase("streaming_yes_shapeless_no", true, false), + ImportsTestCase("streaming_no_shapeless_yes", false, true), + ImportsTestCase("streaming_yes_shapeless_yes", true, true) + )) { + + s"include the correct imports (${test.protoFilename})" in { + + val result: Option[String] = + ProtoSrcGenerator + .build(NoCompressionGen, UseIdiomaticEndpoints(false), new java.io.File(".")) + .generateFrom( + files = Set(protoFile(test.protoFilename)), + serializationType = "", + options = "") + .map(_._3.mkString("\n")) + .headOption + + assert(result.value.contains("import fs2.") == test.shouldIncludeFS2Import) + assert(result.value.contains("import shapeless.") == test.shouldIncludeShapelessImport) + } + + } + } - val expectation = + val bookExpectation = """package com.proto | |import higherkindness.mu.rpc.protocol._