Skip to content

Commit

Permalink
feat: add aad authentication support for cognitive services (#1778)
Browse files Browse the repository at this point in the history
* feat: add aad support

* set aadtoken if in synapse internal env

* remove set aadtoken for now

* rename trait

* remove warnings

* add aad token and endpoint for synapse internal

* rename aad to AAD

* add fuzzing test

* add aad auth test case

* format

* add header for internal usage
  • Loading branch information
serena-ruan authored Jan 11, 2023
1 parent dd1563f commit dc57dea
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import spray.json.DefaultJsonProtocol._

import java.net.URI
import java.util.UUID
import scala.collection.JavaConverters._
import scala.language.existentials
import scala.reflect.internal.util.ScalaClassLoader
Expand Down Expand Up @@ -132,14 +133,86 @@ trait HasSubscriptionKey extends HasServiceParams {

def getSubscriptionKey: String = getScalarParam(subscriptionKey)

def setSubscriptionKey(v: String): this.type = setScalarParam(subscriptionKey, v)
def setSubscriptionKey(v: String): this.type = {
setScalarParam(subscriptionKey, v)
}

def getSubscriptionKeyCol: String = getVectorParam(subscriptionKey)

def setSubscriptionKeyCol(v: String): this.type = setVectorParam(subscriptionKey, v)

}

/**
 * Mixin that adds an `AADToken` service parameter (documented by the parameter itself as the
 * "AAD Token used for authentication") together with its getters and setters.
 */
trait HasAADToken extends HasServiceParams {

  // scalastyle:off field.name
  val AADToken = new ServiceParam[String](this, "AADToken", "AAD Token used for authentication")
  // scalastyle:on field.name

  /** Reads the token through `getScalarParam`. */
  def getAADToken: String = getScalarParam(AADToken)

  /** Stores `v` through `setScalarParam`; typed as `this.type` so calls can be chained. */
  def setAADToken(v: String): this.type = setScalarParam(AADToken, v)

  /** Reads the `Col` variant of the parameter through `getVectorParam`. */
  def getAADTokenCol: String = getVectorParam(AADToken)

  /**
   * Stores `v` through `setVectorParam`; typed as `this.type` so calls can be chained.
   * NOTE(review): presumably `v` names an input column holding per-row tokens -- confirm
   * against `HasServiceParams`.
   */
  def setAADTokenCol(v: String): this.type = setVectorParam(AADToken, v)
}

trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
/**
 * Sets the URL to `https://<v>.cognitiveservices.azure.com/` followed by `urlPath`
 * (with any leading "/" of `urlPath` removed).
 *
 * @param v custom service name used as the host prefix
 */
def setCustomServiceName(v: String): this.type = {
  val relativePath = urlPath.stripPrefix("/")
  setUrl(s"https://$v.cognitiveservices.azure.com/$relativePath")
}

/**
 * Points this stage at an explicit service endpoint by combining `v` with `urlPath`.
 *
 * Exactly one "/" is placed between the two parts, so the endpoint may be passed with or
 * without a trailing slash. Plain concatenation would fuse the host and the first path
 * segment whenever the trailing slash is missing. Endpoints that already end in "/"
 * produce the same URL as before.
 *
 * @param v base endpoint of the service, with or without a trailing "/"
 * @return the result of `setUrl`, typed as `this.type` so calls can be chained
 */
def setEndpoint(v: String): this.type = {
  val base = v.stripSuffix("/")
  val relativePath = urlPath.stripPrefix("/")
  setUrl(s"$base/$relativePath")
}

/**
 * Python source appended to whatever `super.pyAdditionalMethods` contributes. The snippet defines:
 *  - `setCustomServiceName` and `setEndpoint`: call the same-named method on `self._java_obj`,
 *    store the returned object back into `self._java_obj`, and return `self`.
 *  - `_transform`: when `running_on_synapse_internal()` is true, reads `get_mlflow_env_config()`
 *    and sets the AAD token from `driver_aad_token` and the endpoint from
 *    `workload_endpoint` plus "/cognitive/api/", then delegates to `super()._transform`.
 *
 * NOTE(review): in that branch a token or endpoint previously set by the caller is replaced on
 * every transform -- confirm this is intended.
 */
override def pyAdditionalMethods: String = super.pyAdditionalMethods + {
"""def setCustomServiceName(self, value):
| self._java_obj = self._java_obj.setCustomServiceName(value)
| return self
|
|def setEndpoint(self, value):
| self._java_obj = self._java_obj.setEndpoint(value)
| return self
|
|def _transform(self, dataset: DataFrame) -> DataFrame:
| if running_on_synapse_internal():
| from synapse.ml.mlflow import get_mlflow_env_config
| mlflow_env_configs = get_mlflow_env_config()
| self.setAADToken(mlflow_env_configs.driver_aad_token)
| self.setEndpoint(mlflow_env_configs.workload_endpoint + "/cognitive/api/")
| return super()._transform(dataset)
|""".stripMargin
}

/**
 * C# source appended to whatever `super.dotnetAdditionalMethods` contributes. The snippet
 * declares two XML-documented methods, `SetCustomServiceName` and `SetEndpoint`, each of which
 * invokes the corresponding `setCustomServiceName` / `setEndpoint` method by name through
 * `Reference.Invoke` and wraps the result using `dotnetClassWrapperName`. `dotnetClassName`
 * is interpolated as the declared return type.
 */
override def dotnetAdditionalMethods: String = super.dotnetAdditionalMethods + {
s"""/// <summary>
|/// Sets value for service name
|/// </summary>
|/// <param name=\"value\">
|/// Service name of the cognitive service if it's custom domain
|/// </param>
|/// <returns> New $dotnetClassName object </returns>
|public $dotnetClassName SetCustomServiceName(string value) =>
| $dotnetClassWrapperName(Reference.Invoke(\"setCustomServiceName\", value));
|
|/// <summary>
|/// Sets value for endpoint
|/// </summary>
|/// <param name=\"value\">
|/// Endpoint of the cognitive service
|/// </param>
|/// <returns> New $dotnetClassName object </returns>
|public $dotnetClassName SetEndpoint(string value) =>
| $dotnetClassWrapperName(Reference.Invoke(\"setEndpoint\", value));
|""".stripMargin
}
}

object URLEncodingUtils {

private case class NameValuePairInternal(t: (String, String)) extends NameValuePair {
Expand All @@ -153,7 +226,7 @@ object URLEncodingUtils {
}
}

trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey {
trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken {

protected def paramNameToPayloadName(p: Param[_]): String = p match {
case p: ServiceParam[_] => p.payloadName
Expand Down Expand Up @@ -186,6 +259,8 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey {

protected val subscriptionKeyHeaderName = "Ocp-Apim-Subscription-Key"

protected val aadHeaderName = "Authorization"

protected def contentType: Row => String = { _ => "application/json" }

protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = {
Expand All @@ -197,8 +272,17 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey {
} else {
val req = prepareMethod()
req.setURI(new URI(rowToUrl(row)))
getValueOpt(row, subscriptionKey).foreach(
req.setHeader(subscriptionKeyHeaderName, _))
if (getValueOpt(row, subscriptionKey).nonEmpty) {
req.setHeader(subscriptionKeyHeaderName, getValue(row, subscriptionKey))
} else {
getValueOpt(row, AADToken).foreach(s =>
{
req.setHeader(aadHeaderName, "Bearer " + s)
// this is required for internal workload
req.setHeader("x-ms-workload-resource-moniker", UUID.randomUUID().toString)
}
)
}
req.setHeader("Content-Type", contentType(row))

req match {
Expand Down Expand Up @@ -322,7 +406,9 @@ trait DomainHelper {
abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transformer
with ConcurrencyParams with HasOutputCol
with HasURL with ComplexParamsWritable
with HasSubscriptionKey with HasErrorCol with SynapseMLLogging {
with HasSubscriptionKey with HasErrorCol
with HasAADToken with HasCustomCogServiceDomain
with SynapseMLLogging {

setDefault(
outputCol -> (this.uid + "_output"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(u

/** Creates the stage with a generated uid whose prefix matches the class name. */
def this() = this(Identifiable.randomUID("OpenAICompletion"))

// No fixed path suffix here; prepareUrlRoot appends the deployment-specific path to the URL per row.
def urlPath: String = ""

/**
 * Builds the request URL root for a row: the configured URL followed by
 * `openai/deployments/<deploymentName>/completions`, where the deployment name is read
 * from the row via `getValue`.
 */
override protected def prepareUrlRoot: Row => String = { row =>
  val baseUrl = getUrl
  val deployment = getValue(row, deploymentName)
  s"${baseUrl}openai/deployments/$deployment/completions"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)

def this() = this(Identifiable.randomUID("AddDocuments"))

def urlPath: String = ""

setDefault(actionCol -> "@search.action")

override val subscriptionKeyHeaderName = "api-key"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class AddressGeocoder(override val uid: String)
with HasInternalJsonOutputParser with MapsAsyncReply with SynapseMLLogging {
logClass()

def urlPath: String = ""

protected def inputFunc: Row => Option[HttpRequestBase] = {
{ row: Row =>
if (shouldSkip(row)) {
Expand Down Expand Up @@ -79,6 +81,8 @@ class ReverseAddressGeocoder(override val uid: String)
with HasSubscriptionKey with HasURL with HasLatLonPairInput {
logClass()

def urlPath: String = ""

protected def inputFunc: Row => Option[HttpRequestBase] = {
{ row: Row =>
if (shouldSkip(row)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.cognitive.anomaly
import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.nbtest.SynapseUtilities.getAccessToken
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
Expand Down Expand Up @@ -86,6 +87,25 @@ class DetectLastAnomalySuite extends TransformerFuzzing[DetectLastAnomaly] with
assert(result.isAnomaly)
}

// Runs DetectLastAnomaly configured with an AAD token and a custom service name
// (no subscription key is set on this instance) and expects an anomaly in the first row.
test("Basic usage with AAD auth") {
  val token = getAccessToken(
    Secrets.ServicePrincipalClientId,
    Secrets.ServiceConnectionSecret,
    "https://cognitiveservices.azure.com/")

  val detector = new DetectLastAnomaly()
    .setAADToken(token)
    .setCustomServiceName("synapseml-ad-custom")
    .setOutputCol("anomalies")
    .setSeriesCol("inputs")
    .setGranularity("monthly")
    .setErrorCol("errors")

  val toResponse = ADLastResponse.makeFromRowConverter
  val firstAnomaly = detector.transform(df).select("anomalies").collect().head.getStruct(0)
  assert(toResponse(firstAnomaly).isAnomaly)
}

test("minutely Usage") {
val fromRow = ADLastResponse.makeFromRowConverter
val result = fromRow(ad.setGranularity("minutely").transform(df2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ trait PythonWrappable extends BaseWrappable {
|from pyspark.ml.param.shared import *
|from pyspark import keyword_only
|from pyspark.ml.util import JavaMLReadable, JavaMLWritable
|from synapse.ml.core.platform import running_on_synapse_internal
|from synapse.ml.core.serialize.java_params_patch import *
|from pyspark.ml.wrapper import JavaTransformer, JavaEstimator, JavaModel
|from pyspark.ml.evaluation import JavaEvaluator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ object Secrets {
lazy val SynapseExtensionUxHost: String = getSecret("synapse-extension-dxt-ux-host")
lazy val SynapseExtensionSspHost: String = getSecret("synapse-extension-dxt-ssp-host")
lazy val SynapseExtensionWorkspaceId: String = getSecret("synapse-extension-dxt-workspace-id")
lazy val ServiceConnectionSecret: String = getSecret("service-connection-secret")
lazy val ServicePrincipalClientId: String = getSecret("service-principal-clientId")

lazy val SecretRegexpFile: String = getSecret("secret-regexp-file")
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@ object SynapseUtilities {

import SynapseJsonProtocol._

lazy val SynapseToken: String = getAccessToken("https://dev.azuresynapse.net/")
lazy val ArmToken: String = getAccessToken("https://management.azure.com/")
lazy val SynapseToken: String = getAccessToken(ClientId, Secrets.SynapseSpnKey,
"https://dev.azuresynapse.net/")
lazy val ArmToken: String = getAccessToken(ClientId, Secrets.SynapseSpnKey,
"https://management.azure.com/")

val LineSeparator: String = sys.props("line.separator").toLowerCase // Platform agnostic (\r\n:windows, \n:linux)
val Folder = s"build_${BuildInfo.version}/scripts"
Expand Down Expand Up @@ -315,15 +317,15 @@ object SynapseUtilities {
safeSend(deleteRequest)
}

def getAccessToken(reqResource: String): String = {
def getAccessToken(clientId: String, clientSecret: String, reqResource: String): String = {
val createRequest = new HttpPost(s"https://login.microsoftonline.com/$TenantId/oauth2/token")
createRequest.setHeader("Content-Type", "application/x-www-form-urlencoded")
createRequest.setEntity(
new UrlEncodedFormEntity(
List(
("grant_type", "client_credentials"),
("client_id", s"$ClientId"),
("client_secret", s"${Secrets.SynapseSpnKey}"),
("client_id", clientId),
("client_secret", clientSecret),
("resource", reqResource)
).map(p => new BasicNameValuePair(p._1, p._2)).asJava, "UTF-8")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.core.test.fuzzing

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.cognitive.{HasAADToken, HasSubscriptionKey}
import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasInputCol, HasLabelCol, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
Expand Down Expand Up @@ -348,6 +349,27 @@ class FuzzingTest extends TestBase {
}
}

// Every pipeline stage that mixes in HasSubscriptionKey must also mix in HasAADToken,
// unless its class name is listed in the exemptions below.
test("Verify all classes extending HasSubscriptionKey also extend HasAADToken") {
  val keyTrait = classOf[HasSubscriptionKey]
  val tokenTrait = classOf[HasAADToken]
  val exemptions = Set[String](
    // MVAD doesn't support aad token for now
    "com.microsoft.azure.synapse.ml.cognitive.anomaly.DetectMultivariateAnomaly",
    "com.microsoft.azure.synapse.ml.cognitive.anomaly.FitMultivariateAnomaly",
    // TO BE VERIFIED
    "com.microsoft.azure.synapse.ml.cognitive.speech.ConversationTranscription",
    "com.microsoft.azure.synapse.ml.cognitive.speech.SpeechToTextSDK",
    "com.microsoft.azure.synapse.ml.cognitive.speech.TextToSpeech"
  )

  for (stage <- pipelineStages) {
    val stageClass = stage.getClass
    val mustSupportToken = !exemptions(stageClass.getName) && keyTrait.isAssignableFrom(stageClass)
    if (mustSupportToken) {
      assertOrLog(tokenTrait.isAssignableFrom(stageClass),
        stageClass.getName + " needs to extend " + tokenTrait.getName)
    }
  }
}

test("Scan codebase for secrets") {
val excludedFiles = List(
".png",
Expand Down

0 comments on commit dc57dea

Please sign in to comment.