diff --git a/build.sbt b/build.sbt index ca93bd723..48c2e9ed3 100644 --- a/build.sbt +++ b/build.sbt @@ -116,6 +116,7 @@ lazy val docs = Compile / smithySpecs := Seq( (Compile / sourceDirectory).value / "smithy", (ThisBuild / baseDirectory).value / "sampleSpecs" / "test.smithy", + (ThisBuild / baseDirectory).value / "modules" / "guides" / "smithy" / "auth.smithy", (ThisBuild / baseDirectory).value / "sampleSpecs" / "hello.smithy", (ThisBuild / baseDirectory).value / "sampleSpecs" / "kvstore.smithy" ) @@ -621,7 +622,14 @@ lazy val http4s = projectMatrix if (virtualAxes.value.contains(CatsEffect2Axis)) moduleName.value + "-ce2" else moduleName.value - } + }, + Test / allowedNamespaces := Seq( + "smithy4s.hello" + ), + Test / smithySpecs := Seq( + (ThisBuild / baseDirectory).value / "sampleSpecs" / "hello.smithy" + ), + (Test / sourceGenerators) := Seq(genSmithyScala(Test).taskValue) ) .http4sPlatform(allJvmScalaVersions, jvmDimSettings) @@ -777,14 +785,19 @@ lazy val guides = projectMatrix .in(file("modules/guides")) .dependsOn(http4s) .settings( - Compile / allowedNamespaces := Seq("smithy4s.guides.hello"), + Compile / allowedNamespaces := Seq( + "smithy4s.guides.hello", + "smithy4s.guides.auth" + ), smithySpecs := Seq( - (ThisBuild / baseDirectory).value / "modules" / "guides" / "smithy" / "hello.smithy" + (ThisBuild / baseDirectory).value / "modules" / "guides" / "smithy" / "hello.smithy", + (ThisBuild / baseDirectory).value / "modules" / "guides" / "smithy" / "auth.smithy" ), (Compile / sourceGenerators) := Seq(genSmithyScala(Compile).taskValue), isCE3 := true, libraryDependencies ++= Seq( Dependencies.Http4s.emberServer.value, + Dependencies.Http4s.emberClient.value, Dependencies.Weaver.cats.value % Test ) ) diff --git a/modules/docs/markdown/06-guides/endpoint-middleware.md b/modules/docs/markdown/06-guides/endpoint-middleware.md new file mode 100644 index 000000000..33163c2ba --- /dev/null +++ b/modules/docs/markdown/06-guides/endpoint-middleware.md @@ -0,0 +1,306 @@ +--- +sidebar_label: Endpoint Specific Middleware +title: Endpoint Specific Middleware +--- + +It used to be the case that any middleware implemented for smithy4s services would have to operate at the http4s level, without any knowledge of smithy4s or access to the constructs to utilizes. + +As of version `0.17.x` of smithy4s, we have changed this by providing a new mechanism to build and provide middleware. This mechanism is aware of the smithy4s service and endpoints that are derived from your smithy specifications. As such, this unlocks the possibility to build middleware that utilizes and is compliant to the traits and shapes of your smithy specification. + +In this guide, we will show how you can implement a smithy4s middleware that is aware of the authentication traits in your specification and is able to implement authenticate on an endpoint-by-endpoint basis. This is useful if you have different or no authentication on one or more endpoints. + +## ServerEndpointMiddleware / ClientEndpointMiddleware + +`ServerEndpointMiddleware` is the interface that we have provided for implementing service middleware. For some use cases, you will need to use the full interface. However, for this guide and for many use cases, you will be able to rely on the simpler interface called `ServerEndpointMiddleware.Simple`. This interface requires a single method which looks as follows: + +```scala +def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): HttpApp[F] => HttpApp[F] +``` + +This means that given the hints for the service and a specific endpoint, our implementation will provide a transformation of an `HttpApp`. If you are not familiar with `Hints`, they are the smithy4s construct that represents Smithy Traits. They are called hints to avoid naming conflicts and confusion with Scala `trait`s. + +The `ClientEndpointMiddleware` interface is essentially the same as the one for `ServerEndpointMiddleware` with the exception that we are returning a transformation on `Client[F]` instead of `HttpApp[F]`. This looks like: + +```scala +def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): Client[F] => Client[F] +``` + +## Smithy Spec + +Let's look at the smithy specification that we will use for this guide. First, let's define the service. + +```kotlin +$version: "2" + +namespace smithy4s.guides.auth + +use alloy#simpleRestJson + +@simpleRestJson +@httpBearerAuth +service HelloWorldAuthService { + version: "1.0.0", + operations: [SayWorld, HealthCheck] + errors: [NotAuthorizedError] +} +``` + +Here we defined a service that has two operations, `SayWorld` and `HealthCheck`. We defined it such that any of these operations may return an `NotAuthorizedError`. Finally, we annotated the service with the `@httpBearerAuth` [trait](https://smithy.io/2.0/spec/authentication-traits.html#httpbearerauth-trait) to indicate that the service supports authentication via a bearer token. If you are using a different authentication scheme, you can still follow this guide and adapt it for your needs. You can find a full list of smithy-provided schemes [here](https://smithy.io/2.0/spec/authentication-traits.html). If none of the provided traits suit your use case, you can always create a custom trait too. + +Next, let's define our first operation, `SayWorld`: + +```kotlin +@readonly +@http(method: "GET", uri: "/hello", code: 200) +operation SayWorld { + output: World +} + +structure World { + message: String = "World !" +} +``` + +There is nothing authentication-specific defined with this operation, this means that the operation inherits the service-defined authentication scheme (`httpBearerAuth` in this case). Let's contrast this with the `HealthCheck` operation: + +```kotlin +@readonly +@http(method: "GET", uri: "/health", code: 200) +@auth([]) +operation HealthCheck { + output := { + @required + message: String + } +} +``` + +Notice that on this operation we have added the `@auth([])` trait with an empty array. This means that there is no authentication required for this endpoint. In other words, although the service defines an authentication scheme of `httpBearerAuth`, that scheme will not apply to this endpoint. + +Finally, let's define the `NotAuthorizedError` that will be returned when an authentication token is missing or invalid. + +```kotlin +@error("client") +@httpError(401) +structure NotAuthorizedError { + @required + message: String +} +``` + +There is nothing authentication specific about this error, this is a standard smithy http error that will have a 401 status code when returned. + +If you want to see the full smithy model we defined above, you can do so [here](https://github.com/disneystreaming/smithy4s/blob/main/modules/guides/smithy/auth.smithy). + +## Server-side Middleware + +To see the **full code** example of what we walk through below, go [here](https://github.com/disneystreaming/smithy4s/tree/main/modules/guides/src/smithy4s/guides/Auth.scala). + +We will create a server-side middleware that implements the authentication as defined in the smithy spec above. Let's start by creating a few classes that we will use in our middleware. + +```scala mdoc:invisible +import smithy4s.guides.auth._ +import cats.effect._ +import cats.implicits._ +import org.http4s.implicits._ +import org.http4s._ +import smithy4s.http4s.SimpleRestJsonBuilder +import smithy4s._ +import org.http4s.headers.Authorization +import smithy4s.http4s.ServerEndpointMiddleware +``` + +#### AuthChecker + +```scala mdoc:silent +case class ApiToken(value: String) + +trait AuthChecker { + def isAuthorized(token: ApiToken): IO[Boolean] +} + +object AuthChecker extends AuthChecker { + def isAuthorized(token: ApiToken): IO[Boolean] = { + IO.pure( + token.value.nonEmpty + ) // put your logic here, currently just makes sure the token is not empty + } +} +``` + +This is a simple class that we will use to check the validity of a given token. This will be more complex in your own service, but we are keeping it simple here since it is out of the scope of this article and implementations will vary widely depending on your specific application. + +#### The Inner Middleware Implementation + +This function is what is called once we have made sure that the middleware is applicable for a given endpoint. We will show in the next step how to tell if the middleware is applicable or not. For now though, we will just focus on what the middleware does once we know that it needs to be applied to a given endpoint. + +```scala mdoc:silent +def middleware( + authChecker: AuthChecker // 1 +): HttpApp[IO] => HttpApp[IO] = { inputApp => // 2 + HttpApp[IO] { request => // 3 + val maybeKey = request.headers // 4 + .get[`Authorization`] + .collect { + case Authorization( + Credentials.Token(AuthScheme.Bearer, value) + ) => + value + } + .map { ApiToken.apply } + + val isAuthorized = maybeKey + .map { key => + authChecker.isAuthorized(key) // 5 + } + .getOrElse(IO.pure(false)) + + isAuthorized.ifM( + ifTrue = inputApp(request), // 6 + ifFalse = IO.raiseError(new NotAuthorizedError("Not authorized!")) // 7 + ) + } +} +``` + +Let's break down what we did above step by step. The step numbers below correspond to the comment numbers above. + +1. Pass an instance of `AuthChecker` that we can use to verify auth tokens are valid in this middleware +2. `inputApp` is the `HttpApp[IO]` that we are transforming in this middleware. +3. Here we create a new HttpApp, the one that we will be returning from this function we are creating. +4. Here we extract the value of the `Authorization` header, if it is present. +5. If the header had a value, we now send that value into the `AuthChecker` to see if it is valid. +6. If the token was found to be valid, we pass the request into the `inputApp` from step 2 in order to get a response. +7. If the header was found to be invalid, we return the `NotAuthorizedError` that we defined in our smithy file above. + +#### ServerEndpointMiddleware.Simple + +Next, let's create our middleware by implementing the `ServerEndpointMiddleware.Simple` interface we discussed above. + +```scala mdoc:silent +object AuthMiddleware { + def apply( + authChecker: AuthChecker // 1 + ): ServerEndpointMiddleware[IO] = + new ServerEndpointMiddleware.Simple[IO] { + private val mid: HttpApp[IO] => HttpApp[IO] = middleware(authChecker) // 2 + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): HttpApp[IO] => HttpApp[IO] = { + serviceHints.get[smithy.api.HttpBearerAuth] match { // 3 + case Some(_) => + endpointHints.get[smithy.api.Auth] match { // 4 + case Some(auths) if auths.value.isEmpty => identity // 5 + case _ => mid // 6 + } + case None => identity + } + } + } +} +``` + +1. Pass in an instance of `AuthChecker` for the middleware to use. This is how the middleware will know if a given token is valid or not. +2. This is the function that we defined in the step above. +3. Check and see if the service at hand does in fact have the `httpBearerAuth` trait on it. If it doesn't, then we will not do our auth checks. If it does, then we will proceed. +4. Here we are getting the `@auth` trait from the operation (endpoint in smithy4s lingo). We need to check for this trait because of step 5. +5. Here we are checking that IF the auth trait is on this endpoint AND the auth trait contains an empty array THEN we are performing NO authentication checks. This is how we handle the `@auth([])` trait that is present on the `HealthCheck` operation we defined above. +6. IF the auth trait is NOT present on the operation, OR it is present AND it contains one or more authentication schemes, we apply the middleware. + +#### Using the Middleware + +From here, we can pass our middleware into our `SimpleRestJsonBuilder` as follows: + +```scala mdoc:silent +object HelloWorldAuthImpl extends HelloWorldAuthService[IO] { + def sayWorld(): IO[World] = World().pure[IO] + def healthCheck(): IO[HealthCheckOutput] = HealthCheckOutput("Okay!").pure[IO] +} + +val routes = SimpleRestJsonBuilder + .routes(HelloWorldAuthImpl) + .middleware(AuthMiddleware(AuthChecker)) + .resource +``` + +And that's it. Now we have a middleware that will apply an authentication check on incoming requests whenever relevant, as defined in our smithy file. + +## Client-side Middleware + +To see the **full code** example of what we walk through below, go [here](https://github.com/disneystreaming/smithy4s/tree/main/modules/guides/src/smithy4s/guides/AuthClient.scala). + +It is possible that you have a client where you want to apply a similar type of middleware that alters some part of a request depending on the endpoint being targeted. In this part of the guide, we will show how you can do this for a client using the same smithy specification we defined above. We will make it so our authentication token is only sent if we are targeting an endpoint which requires it. + +#### ClientEndpointMiddleware.Simple + +The interface that we define for this middleware is going to look very similar to the one we defined above. This makes sense because this middleware is effectively the dual of the middleware above. + +```scala mdoc:invisible +import org.http4s.client._ +import smithy4s.http4s.ClientEndpointMiddleware +``` + +```scala mdoc:silent +object Middleware { + + private def middleware(bearerToken: String): Client[IO] => Client[IO] = { // 1 + inputClient => + Client[IO] { request => + val newRequest = request.withHeaders( // 2 + Authorization(Credentials.Token(AuthScheme.Bearer, bearerToken)) + ) + + inputClient.run(newRequest) + } + } + + def apply(bearerToken: String): ClientEndpointMiddleware[IO] = // 3 + new ClientEndpointMiddleware.Simple[IO] { + private val mid = middleware(bearerToken) + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): Client[IO] => Client[IO] = { + serviceHints.get[smithy.api.HttpBearerAuth] match { + case Some(_) => + endpointHints.get[smithy.api.Auth] match { + case Some(auths) if auths.value.isEmpty => identity + case _ => mid + } + case None => identity + } + } + } + +} +``` + +1. Here we are creating an inner middleware function, just like we did above. The only differences are that this time we are adding a value to the request instead of extracting one from it and we are operating on `Client` instead of `HttpApp`. +2. Add the `Authorization` header to the request and pass it to the `inputClient` that we are transforming in this middleware. +3. This function is actually the *exact same* as the function for the middleware we implemented above. The only differences are that this apply method accepts a `bearerToken` as a parameter and returns a function on `Client` instead of `HttpApp`. The provided `bearerToken` is what we will add into the `Authorization` header when applicable. + +#### SimpleRestJsonBuilder + +As above, we now just need to wire our middleware into our actual implementation. Here we are constructing a client and specifying the middleware we just defined. + +```scala mdoc:silent +def apply(http4sClient: Client[IO]): Resource[IO, HelloWorldAuthService[IO]] = + SimpleRestJsonBuilder(HelloWorldAuthService) + .client(http4sClient) + .uri(Uri.unsafeFromString("http://localhost:9000")) + .middleware(Middleware("my-token")) // creating our middleware here + .resource +``` + +## Conclusion + +Once again, if you want to see the **full code** examples of the above, you can find them [here](https://github.com/disneystreaming/smithy4s/tree/main/modules/guides/src/smithy4s/guides/). + +Hopefully this guide gives you a good idea of how you can create a middleware that takes your smithy specification into account. This guide shows a very simple use case of what is possible with a middleware like this. If you have a more advanced use case, you can use this guide as a reference and as always you can reach out to us for insight or help. diff --git a/modules/guides/smithy/auth.smithy b/modules/guides/smithy/auth.smithy new file mode 100644 index 000000000..14b95fc3e --- /dev/null +++ b/modules/guides/smithy/auth.smithy @@ -0,0 +1,41 @@ +$version: "2" + +namespace smithy4s.guides.auth + +use alloy#simpleRestJson + +@simpleRestJson +@httpBearerAuth +service HelloWorldAuthService { + version: "1.0.0", + operations: [SayWorld, HealthCheck] + errors: [NotAuthorizedError] +} + + +@readonly +@http(method: "GET", uri: "/hello", code: 200) +operation SayWorld { + output: World +} + +@readonly +@http(method: "GET", uri: "/health", code: 200) +@auth([]) +operation HealthCheck { + output := { + @required + message: String + } +} + +structure World { + message: String = "World !" +} + +@error("client") +@httpError(401) +structure NotAuthorizedError { + @required + message: String +} diff --git a/modules/guides/src/smithy4s/guides/Auth.scala b/modules/guides/src/smithy4s/guides/Auth.scala new file mode 100644 index 000000000..1c85b3804 --- /dev/null +++ b/modules/guides/src/smithy4s/guides/Auth.scala @@ -0,0 +1,125 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s.guides + +import smithy4s.guides.auth._ +import cats.effect._ +import cats.implicits._ +import org.http4s.implicits._ +import org.http4s.ember.server._ +import org.http4s._ +import com.comcast.ip4s._ +import smithy4s.http4s.SimpleRestJsonBuilder +import smithy4s.Hints +import org.http4s.headers.Authorization +import smithy4s.http4s.ServerEndpointMiddleware + +final case class ApiToken(value: String) + +object HelloWorldAuthImpl extends HelloWorldAuthService[IO] { + def sayWorld(): IO[World] = World().pure[IO] + def healthCheck(): IO[HealthCheckOutput] = HealthCheckOutput("Okay!").pure[IO] +} + +trait AuthChecker { + def isAuthorized(token: ApiToken): IO[Boolean] +} + +object AuthChecker extends AuthChecker { + def isAuthorized(token: ApiToken): IO[Boolean] = { + IO.pure( + token.value.nonEmpty + ) // put your logic here, currently just makes sure the token is not empty + } +} + +object AuthExampleRoutes { + import org.http4s.server.middleware._ + + private val helloRoutes: Resource[IO, HttpRoutes[IO]] = + SimpleRestJsonBuilder + .routes(HelloWorldAuthImpl) + .middleware(AuthMiddleware(AuthChecker)) + .resource + + val all: Resource[IO, HttpRoutes[IO]] = + helloRoutes +} + +object AuthMiddleware { + + private def middleware( + authChecker: AuthChecker + ): HttpApp[IO] => HttpApp[IO] = { inputApp => + HttpApp[IO] { request => + val maybeKey = request.headers + .get[`Authorization`] + .collect { + case Authorization( + Credentials.Token(AuthScheme.Bearer, value) + ) => + value + } + .map { ApiToken.apply } + + val isAuthorized = maybeKey + .map { key => + authChecker.isAuthorized(key) + } + .getOrElse(IO.pure(false)) + + isAuthorized.ifM( + ifTrue = inputApp(request), + ifFalse = IO.raiseError(new NotAuthorizedError("Not authorized!")) + ) + } + } + + def apply( + authChecker: AuthChecker + ): ServerEndpointMiddleware[IO] = + new ServerEndpointMiddleware.Simple[IO] { + private val mid: HttpApp[IO] => HttpApp[IO] = middleware(authChecker) + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): HttpApp[IO] => HttpApp[IO] = { + serviceHints.get[smithy.api.HttpBearerAuth] match { + case Some(_) => + endpointHints.get[smithy.api.Auth] match { + case Some(auths) if auths.value.isEmpty => identity + case _ => mid + } + case None => identity + } + } + } +} + +// test with `curl localhost:9000/hello -H 'Authorization: Bearer Some'` +// or `curl localhost:9000/hello` +object AuthExampleMain extends IOApp.Simple { + val run = (for { + routes <- AuthExampleRoutes.all + server <- EmberServerBuilder + .default[IO] + .withPort(port"9000") + .withHost(host"localhost") + .withHttpApp(routes.orNotFound) + .build + } yield server).useForever +} diff --git a/modules/guides/src/smithy4s/guides/AuthClient.scala b/modules/guides/src/smithy4s/guides/AuthClient.scala new file mode 100644 index 000000000..4c245c872 --- /dev/null +++ b/modules/guides/src/smithy4s/guides/AuthClient.scala @@ -0,0 +1,80 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s.guides + +import smithy4s.guides.auth._ +import smithy4s.http4s._ +import cats.effect._ +import cats.implicits._ +import org.http4s.implicits._ +import org.http4s._ +import com.comcast.ip4s._ +import org.http4s.client._ +import org.http4s.ember.client.EmberClientBuilder +import smithy4s.Hints +import org.http4s.headers.Authorization + +object AuthClient { + def apply(http4sClient: Client[IO]): Resource[IO, HelloWorldAuthService[IO]] = + SimpleRestJsonBuilder(HelloWorldAuthService) + .client(http4sClient) + .uri(Uri.unsafeFromString("http://localhost:9000")) + .middleware(Middleware("my-token")) + .resource +} + +object Middleware { + + private def middleware(bearerToken: String): Client[IO] => Client[IO] = { + inputClient => + Client[IO] { request => + val newRequest = request.withHeaders( + Authorization(Credentials.Token(AuthScheme.Bearer, bearerToken)) + ) + + inputClient.run(newRequest) + } + } + + def apply(bearerToken: String): ClientEndpointMiddleware[IO] = + new ClientEndpointMiddleware.Simple[IO] { + private val mid = middleware(bearerToken) + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): Client[IO] => Client[IO] = { + serviceHints.get[smithy.api.HttpBearerAuth] match { + case Some(_) => + endpointHints.get[smithy.api.Auth] match { + case Some(auths) if auths.value.isEmpty => identity + case _ => mid + } + case None => identity + } + } + } + +} + +object AuthClientExampleMain extends IOApp.Simple { + val run = (for { + client <- EmberClientBuilder.default[IO].build + authClient <- AuthClient(client) + health <- Resource.eval(authClient.healthCheck().flatMap(IO.println)) + hello <- Resource.eval(authClient.sayWorld().flatMap(IO.println)) + } yield ()).use_ +} diff --git a/modules/http4s/src/smithy4s/http4s/ClientEndpointMiddleware.scala b/modules/http4s/src/smithy4s/http4s/ClientEndpointMiddleware.scala new file mode 100644 index 000000000..84f865662 --- /dev/null +++ b/modules/http4s/src/smithy4s/http4s/ClientEndpointMiddleware.scala @@ -0,0 +1,51 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s +package http4s + +import org.http4s.client.Client + +// format: off +trait ClientEndpointMiddleware[F[_]] { + def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): Client[F] => Client[F] +} +// format: on + +object ClientEndpointMiddleware { + + trait Simple[F[_]] extends ClientEndpointMiddleware[F] { + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): Client[F] => Client[F] + + final def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): Client[F] => Client[F] = + prepareWithHints(service.hints, endpoint.hints) + } + + def noop[F[_]]: ClientEndpointMiddleware[F] = + new ClientEndpointMiddleware[F] { + override def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): Client[F] => Client[F] = identity + } + +} diff --git a/modules/http4s/src/smithy4s/http4s/ServerEndpointMiddleware.scala b/modules/http4s/src/smithy4s/http4s/ServerEndpointMiddleware.scala new file mode 100644 index 000000000..ca3557a2f --- /dev/null +++ b/modules/http4s/src/smithy4s/http4s/ServerEndpointMiddleware.scala @@ -0,0 +1,54 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s +package http4s + +import org.http4s.HttpApp + +// format: off +trait ServerEndpointMiddleware[F[_]] { + def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): HttpApp[F] => HttpApp[F] +} +// format: on + +object ServerEndpointMiddleware { + + trait Simple[F[_]] extends ServerEndpointMiddleware[F] { + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): HttpApp[F] => HttpApp[F] + + final def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): HttpApp[F] => HttpApp[F] = + prepareWithHints(service.hints, endpoint.hints) + } + + private[http4s] type EndpointMiddleware[F[_], Op[_, _, _, _, _]] = + Endpoint[Op, _, _, _, _, _] => HttpApp[F] => HttpApp[F] + + def noop[F[_]]: ServerEndpointMiddleware[F] = + new ServerEndpointMiddleware[F] { + override def prepare[Alg[_[_, _, _, _, _]]](service: Service[Alg])( + endpoint: Endpoint[service.Operation, _, _, _, _, _] + ): HttpApp[F] => HttpApp[F] = identity + } + +} diff --git a/modules/http4s/src/smithy4s/http4s/SimpleProtocolBuilder.scala b/modules/http4s/src/smithy4s/http4s/SimpleProtocolBuilder.scala index 6d4311dd9..184fac486 100644 --- a/modules/http4s/src/smithy4s/http4s/SimpleProtocolBuilder.scala +++ b/modules/http4s/src/smithy4s/http4s/SimpleProtocolBuilder.scala @@ -48,7 +48,8 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit new RouterBuilder[Alg, F]( service, impl, - PartialFunction.empty + PartialFunction.empty, + ServerEndpointMiddleware.noop[F] ) } @@ -65,7 +66,8 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit new RouterBuilder[Alg, F]( service, impl, - PartialFunction.empty + PartialFunction.empty, + ServerEndpointMiddleware.noop[F] ) } @@ -76,11 +78,17 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit ] private[http4s] ( client: Client[F], val service: smithy4s.Service[Alg], - uri: Uri = uri"http://localhost:8080" + uri: Uri = uri"http://localhost:8080", + middleware: ClientEndpointMiddleware[F] = ClientEndpointMiddleware.noop[F] ) { def uri(uri: Uri): ClientBuilder[Alg, F] = - new ClientBuilder[Alg, F](this.client, this.service, uri) + new ClientBuilder[Alg, F](this.client, this.service, uri, this.middleware) + + def middleware( + mid: ClientEndpointMiddleware[F] + ): ClientBuilder[Alg, F] = + new ClientBuilder[Alg, F](this.client, this.service, this.uri, mid) def resource: Resource[F, FunctorAlgebra[Alg, F]] = use.leftWiden[Throwable].liftTo[Resource[F, *]] @@ -95,7 +103,8 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit service, client, EntityCompiler - .fromCodecAPI[F](codecs) + .fromCodecAPI[F](codecs), + middleware ) } .map(service.fromPolyFunction[Kind1[F]#toKind5](_)) @@ -108,8 +117,11 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit ] private[http4s] ( service: smithy4s.Service[Alg], impl: FunctorAlgebra[Alg, F], - errorTransformation: PartialFunction[Throwable, F[Throwable]] - )(implicit F: EffectCompat[F]) { + errorTransformation: PartialFunction[Throwable, F[Throwable]], + middleware: ServerEndpointMiddleware[F] + )(implicit + F: EffectCompat[F] + ) { val entityCompiler = EntityCompiler.fromCodecAPI(codecs) @@ -117,12 +129,17 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit def mapErrors( fe: PartialFunction[Throwable, Throwable] ): RouterBuilder[Alg, F] = - new RouterBuilder(service, impl, fe andThen (e => F.pure(e))) + new RouterBuilder(service, impl, fe andThen (e => F.pure(e)), middleware) def flatMapErrors( fe: PartialFunction[Throwable, F[Throwable]] ): RouterBuilder[Alg, F] = - new RouterBuilder(service, impl, fe) + new RouterBuilder(service, impl, fe, middleware) + + def middleware( + mid: ServerEndpointMiddleware[F] + ): RouterBuilder[Alg, F] = + new RouterBuilder[Alg, F](service, impl, errorTransformation, mid) def make: Either[UnsupportedProtocolError, HttpRoutes[F]] = checkProtocol(service, protocolTag) @@ -133,7 +150,8 @@ abstract class SimpleProtocolBuilder[P](val codecs: CodecAPI)(implicit service, service.toPolyFunction[Kind1[F]#toKind5](impl), errorTransformation, - entityCompiler + entityCompiler, + middleware ).routes } diff --git a/modules/http4s/src/smithy4s/http4s/SmithyHttp4sReverseRouter.scala b/modules/http4s/src/smithy4s/http4s/SmithyHttp4sReverseRouter.scala index 36ab95b19..784ad135a 100644 --- a/modules/http4s/src/smithy4s/http4s/SmithyHttp4sReverseRouter.scala +++ b/modules/http4s/src/smithy4s/http4s/SmithyHttp4sReverseRouter.scala @@ -27,7 +27,8 @@ class SmithyHttp4sReverseRouter[Alg[_[_, _, _, _, _]], Op[_, _, _, _, _], F[_]]( baseUri: Uri, service: smithy4s.Service.Aux[Alg, Op], client: Client[F], - entityCompiler: EntityCompiler[F] + entityCompiler: EntityCompiler[F], + middleware: ClientEndpointMiddleware[F] )(implicit effect: EffectCompat[F]) extends FunctorInterpreter[Op, F] { // format: on @@ -55,7 +56,8 @@ class SmithyHttp4sReverseRouter[Alg[_[_, _, _, _, _]], Op[_, _, _, _, _], F[_]]( baseUri, client, endpoint, - compilerContext + compilerContext, + middleware.prepare(service)(endpoint) ) .left .map { e => diff --git a/modules/http4s/src/smithy4s/http4s/SmithyHttp4sRouter.scala b/modules/http4s/src/smithy4s/http4s/SmithyHttp4sRouter.scala index c42baae2a..47cc0ec9e 100644 --- a/modules/http4s/src/smithy4s/http4s/SmithyHttp4sRouter.scala +++ b/modules/http4s/src/smithy4s/http4s/SmithyHttp4sRouter.scala @@ -23,15 +23,21 @@ import cats.implicits._ import org.http4s._ import smithy4s.http4s.internals.SmithyHttp4sServerEndpoint import smithy4s.kinds._ +import org.typelevel.vault.Key +import cats.effect.SyncIO // format: off class SmithyHttp4sRouter[Alg[_[_, _, _, _, _]], Op[_, _, _, _, _], F[_]]( service: smithy4s.Service.Aux[Alg, Op], impl: FunctorInterpreter[Op, F], errorTransformation: PartialFunction[Throwable, F[Throwable]], - entityCompiler: EntityCompiler[F] + entityCompiler: EntityCompiler[F], + middleware: ServerEndpointMiddleware[F] )(implicit effect: EffectCompat[F]) { + private val pathParamsKey = + Key.newKey[SyncIO, smithy4s.http.PathParams].unsafeRunSync() + private val compilerContext = internals.CompilerContext.make(entityCompiler) val routes: HttpRoutes[F] = Kleisli { request => @@ -39,7 +45,7 @@ class SmithyHttp4sRouter[Alg[_[_, _, _, _, _]], Op[_, _, _, _, _], F[_]]( endpoints <- perMethodEndpoint.get(request.method).toOptionT[F] path = request.uri.path.segments.map(_.decoded()).toArray (endpoint, pathParams) <- endpoints.collectFirstSome(_.matchTap(path)).toOptionT[F] - response <- OptionT.liftF(endpoint.run(pathParams, request)) + response <- OptionT.liftF(endpoint.httpApp(request.withAttribute(pathParamsKey, pathParams))) } yield response } // format: on @@ -51,7 +57,9 @@ class SmithyHttp4sRouter[Alg[_[_, _, _, _, _]], Op[_, _, _, _, _], F[_]]( impl, ep, compilerContext, - errorTransformation + errorTransformation, + middleware.prepare(service) _, + pathParamsKey ) } .collect { case Right(http4sEndpoint) => diff --git a/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sClientEndpoint.scala b/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sClientEndpoint.scala index 2097b074f..a2fea9599 100644 --- a/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sClientEndpoint.scala +++ b/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sClientEndpoint.scala @@ -46,7 +46,8 @@ private[http4s] object SmithyHttp4sClientEndpoint { baseUri: Uri, client: Client[F], endpoint: Endpoint[Op, I, E, O, SI, SO], - compilerContext: CompilerContext[F] + compilerContext: CompilerContext[F], + middleware: Client[F] => Client[F] ): Either[ HttpEndpoint.HttpEndpointError, SmithyHttp4sClientEndpoint[F, Op, I, E, O, SI, SO] @@ -65,7 +66,8 @@ private[http4s] object SmithyHttp4sClientEndpoint { method, endpoint, httpEndpoint, - compilerContext + compilerContext, + middleware ) } } @@ -79,12 +81,15 @@ private[http4s] class SmithyHttp4sClientEndpointImpl[F[_], Op[_, _, _, _, _], I, method: org.http4s.Method, endpoint: Endpoint[Op, I, E, O, SI, SO], httpEndpoint: HttpEndpoint[I], - compilerContext: CompilerContext[F] + compilerContext: CompilerContext[F], + middleware: Client[F] => Client[F] )(implicit effect: EffectCompat[F]) extends SmithyHttp4sClientEndpoint[F, Op, I, E, O, SI, SO] { // format: on + private val transformedClient: Client[F] = middleware(client) + def send(input: I): F[O] = { - client + transformedClient .run(inputToRequest(input)) .use { response => outputFromResponse(response) diff --git a/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sServerEndpoint.scala b/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sServerEndpoint.scala index 0333a758c..a49892dfb 100644 --- a/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sServerEndpoint.scala +++ b/modules/http4s/src/smithy4s/http4s/internals/SmithyHttp4sServerEndpoint.scala @@ -31,6 +31,8 @@ import smithy4s.http.Metadata import smithy4s.http._ import smithy4s.schema.Alt import smithy4s.kinds._ +import org.http4s.HttpApp +import org.typelevel.vault.Key /** * A construct that encapsulates a smithy4s endpoint, and exposes @@ -39,7 +41,7 @@ import smithy4s.kinds._ private[http4s] trait SmithyHttp4sServerEndpoint[F[_]] { def method: org.http4s.Method def matches(path: Array[String]): Option[PathParams] - def run(pathParams: PathParams, request: Request[F]): F[Response[F]] + def httpApp: HttpApp[F] def matchTap( path: Array[String] @@ -49,15 +51,19 @@ private[http4s] trait SmithyHttp4sServerEndpoint[F[_]] { private[http4s] object SmithyHttp4sServerEndpoint { + // format: off def make[F[_]: EffectCompat, Op[_, _, _, _, _], I, E, O, SI, SO]( impl: FunctorInterpreter[Op, F], endpoint: Endpoint[Op, I, E, O, SI, SO], compilerContext: CompilerContext[F], - errorTransformation: PartialFunction[Throwable, F[Throwable]] + errorTransformation: PartialFunction[Throwable, F[Throwable]], + middleware: ServerEndpointMiddleware.EndpointMiddleware[F, Op], + pathParamsKey: Key[PathParams] ): Either[ HttpEndpoint.HttpEndpointError, SmithyHttp4sServerEndpoint[F] ] = + // format: on HttpEndpoint.cast(endpoint).flatMap { httpEndpoint => toHttp4sMethod(httpEndpoint.method) .leftMap { e => @@ -72,7 +78,9 @@ private[http4s] object SmithyHttp4sServerEndpoint { method, httpEndpoint, compilerContext, - errorTransformation + errorTransformation, + middleware, + pathParamsKey ) } } @@ -87,6 +95,8 @@ private[http4s] class SmithyHttp4sServerEndpointImpl[F[_], Op[_, _, _, _, _], I, httpEndpoint: HttpEndpoint[I], compilerContext: CompilerContext[F], errorTransformation: PartialFunction[Throwable, F[Throwable]], + middleware: ServerEndpointMiddleware.EndpointMiddleware[F, Op], + pathParamsKey: Key[PathParams] )(implicit F: EffectCompat[F]) extends SmithyHttp4sServerEndpoint[F] { // format: on import compilerContext._ @@ -97,18 +107,22 @@ private[http4s] class SmithyHttp4sServerEndpointImpl[F[_], Op[_, _, _, _, _], I, httpEndpoint.matches(path) } - def run(pathParams: PathParams, request: Request[F]): F[Response[F]] = { - val run: F[O] = for { - metadata <- getMetadata(pathParams, request) - input <- extractInput(metadata, request) - output <- (impl(endpoint.wrap(input)): F[O]) - } yield output + private val applyMiddleware = middleware(endpoint) - run.recoverWith(transformError).attempt.flatMap { - case Left(error) => errorResponse(error) - case Right(output) => successResponse(output) - } - } + override val httpApp: HttpApp[F] = + applyMiddleware(HttpApp[F] { req => + val pathParams = req.attributes.lookup(pathParamsKey).getOrElse(Map.empty) + + val run: F[O] = for { + metadata <- getMetadata(pathParams, req) + input <- extractInput(metadata, req) + output <- (impl(endpoint.wrap(input)): F[O]) + } yield output + + run + .recoverWith(transformError) + .flatMap(successResponse) + }).handleErrorWith(error => Kleisli.liftF(errorResponse(error))) private val inputSchema: Schema[I] = endpoint.input private val outputSchema: Schema[O] = endpoint.output diff --git a/modules/http4s/test/src/smithy4s/http4s/EndpointSpecificMiddlewareSpec.scala b/modules/http4s/test/src/smithy4s/http4s/EndpointSpecificMiddlewareSpec.scala new file mode 100644 index 000000000..d9f2e6955 --- /dev/null +++ b/modules/http4s/test/src/smithy4s/http4s/EndpointSpecificMiddlewareSpec.scala @@ -0,0 +1,193 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s +package http4s + +import weaver._ +import smithy4s.hello._ +import org.http4s.HttpApp +import cats.effect.IO +import cats.data.OptionT +import org.http4s.Uri +import org.http4s._ +import fs2.Collector +import org.http4s.client.Client +import cats.Eq +import cats.effect.Resource + +object ServerEndpointMiddlewareSpec extends SimpleIOSuite { + + private implicit val greetingEq: Eq[Greeting] = Eq.fromUniversalEquals + private implicit val throwableEq: Eq[Throwable] = Eq.fromUniversalEquals + + test("server - middleware is applied") { + serverMiddlewareTest( + shouldFailInMiddleware = true, + Request[IO](Method.POST, Uri.unsafeFromString("/bob")), + response => + IO.pure(expect.eql(response.status, Status.InternalServerError)) + ) + } + + test( + "server - middleware allows passing through to underlying implementation" + ) { + serverMiddlewareTest( + shouldFailInMiddleware = false, + Request[IO](Method.POST, Uri.unsafeFromString("/bob")), + response => { + response.body.compile + .to(Collector.supportsArray(Array)) + .map(new String(_)) + .map { body => + expect.eql(response.status, Status.Ok) && + expect.eql(body, """{"message":"Hello, bob"}""") + } + } + ) + } + + test("client - middleware is applied") { + clientMiddlewareTest( + shouldFailInMiddleware = true, + service => + service.hello("bob").attempt.map { result => + expect.eql(result, Left(new GenericServerError(Some("failed")))) + } + ) + } + + test("client - send request through middleware") { + clientMiddlewareTest( + shouldFailInMiddleware = false, + service => + service.hello("bob").attempt.map { result => + expect.eql(result, Right(Greeting("Hello, bob"))) + } + ) + } + + private def serverMiddlewareTest( + shouldFailInMiddleware: Boolean, + request: Request[IO], + expect: Response[IO] => IO[Expectations] + )(implicit pos: SourceLocation): IO[Expectations] = { + val service = + SimpleRestJsonBuilder + .routes(HelloImpl) + .middleware( + new TestServerMiddleware(shouldFail = shouldFailInMiddleware) + ) + .make + .toOption + .get + + service(request) + .flatMap(res => OptionT.liftF(expect(res))) + .getOrElse( + failure("unable to run request") + ) + } + + private def clientMiddlewareTest( + shouldFailInMiddleware: Boolean, + expect: HelloWorldService[IO] => IO[Expectations] + ): IO[Expectations] = { + val serviceNoMiddleware: HttpApp[IO] = + SimpleRestJsonBuilder + .routes(HelloImpl) + .make + .toOption + .get + .orNotFound + + val client: HelloWorldService[IO] = { + val http4sClient = Client.fromHttpApp(serviceNoMiddleware) + SimpleRestJsonBuilder(HelloWorldService) + .client(http4sClient) + .middleware( + new TestClientMiddleware(shouldFail = shouldFailInMiddleware) + ) + .use + .toOption + .get + } + + expect(client) + } + + private object HelloImpl extends HelloWorldService[IO] { + def hello(name: String, town: Option[String]): IO[Greeting] = IO.pure( + Greeting(s"Hello, $name") + ) + } + + private final class TestServerMiddleware(shouldFail: Boolean) + extends ServerEndpointMiddleware.Simple[IO] { + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): HttpApp[IO] => HttpApp[IO] = { inputApp => + HttpApp[IO] { request => + val hasTag: (Hints, String) => Boolean = (hints, tagName) => + hints.get[smithy.api.Tags].exists(_.value.contains(tagName)) + // check for tags in hints to test that proper hints are sent into the prepare method + if ( + hasTag(serviceHints, "testServiceTag") && + hasTag(endpointHints, "testOperationTag") + ) { + if (shouldFail) { + IO.raiseError(new GenericServerError(Some("failed"))) + } else { + inputApp(request) + } + } else { + IO.raiseError(new Exception("didn't find tags in hints")) + } + } + } + } + + private final class TestClientMiddleware(shouldFail: Boolean) + extends ClientEndpointMiddleware.Simple[IO] { + def prepareWithHints( + serviceHints: Hints, + endpointHints: Hints + ): Client[IO] => Client[IO] = { inputClient => + Client[IO] { request => + val hasTag: (Hints, String) => Boolean = (hints, tagName) => + hints.get[smithy.api.Tags].exists(_.value.contains(tagName)) + // check for tags in hints to test that proper hints are sent into the prepare method + if ( + hasTag(serviceHints, "testServiceTag") && + hasTag(endpointHints, "testOperationTag") + ) { + if (shouldFail) { + Resource.eval(IO.raiseError(new GenericServerError(Some("failed")))) + } else { + inputClient.run(request) + } + } else { + Resource.eval( + IO.raiseError(new Exception("didn't find tags in hints")) + ) + } + } + } + } + +} diff --git a/sampleSpecs/hello.smithy b/sampleSpecs/hello.smithy index c3115c3b8..8a47e2ed2 100644 --- a/sampleSpecs/hello.smithy +++ b/sampleSpecs/hello.smithy @@ -3,6 +3,7 @@ namespace smithy4s.hello use alloy#simpleRestJson @simpleRestJson +@tags(["testServiceTag"]) service HelloWorldService { version: "1.0.0", // Indicates that all operations in `HelloWorldService`, @@ -18,6 +19,7 @@ structure GenericServerError { } @http(method: "POST", uri: "/{name}", code: 200) +@tags(["testOperationTag"]) operation Hello { input: Person, output: Greeting