Skip to content

Commit

Permalink
chore: switch to msi based tokens (#2221)
Browse files Browse the repository at this point in the history
* chore: switch to msi based tokens

* chore: fix style

* chore: add docker msi token

* fix blob upload to use msi

* wip

* wip

* continue removing code
  • Loading branch information
mhamilton723 authored May 3, 2024
1 parent a72b94b commit 7b54e89
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package com.microsoft.azure.synapse.ml.services.anomaly

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.nbtest.SynapseUtilities.getAccessToken
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row}
Expand Down Expand Up @@ -88,9 +88,7 @@ class DetectLastAnomalySuite extends TransformerFuzzing[DetectLastAnomaly] with
}

test("Basic usage with AAD auth") {
val aadToken = getAccessToken(Secrets.ServicePrincipalClientId,
Secrets.ServiceConnectionSecret,
"https://cognitiveservices.azure.com/")
val aadToken = getAccessToken("https://cognitiveservices.azure.com/")
val ad = new DetectLastAnomaly()
.setAADToken(aadToken)
.setCustomServiceName("synapseml-ad-custom")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.nbtest.SynapseUtilities.getAccessToken
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.scalactic.Equality
Expand All @@ -26,6 +26,12 @@ class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with Op

import spark.implicits._

override def beforeAll(): Unit = {
val aadToken = getAccessToken("https://cognitiveservices.azure.com/")
println(s"Triggering token creation early ${aadToken.length}")
super.beforeAll()
}

def newCompletion: OpenAICompletion = new OpenAICompletion()
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
Expand Down Expand Up @@ -60,10 +66,7 @@ class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with Op
}

test("Basic usage with AAD auth") {
val aadToken = getAccessToken(
Secrets.ServicePrincipalClientId,
Secrets.ServiceConnectionSecret,
"https://cognitiveservices.azure.com/")
val aadToken = getAccessToken("https://cognitiveservices.azure.com/")

val completion = new OpenAICompletion()
.setAADToken(aadToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.ml.util.MLReadable
Expand All @@ -14,6 +15,12 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

import spark.implicits._

override def beforeAll(): Unit = {
val aadToken = getAccessToken("https://cognitiveservices.azure.com/")
println(s"Triggering token creation early ${aadToken.length}")
super.beforeAll()
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ object Secrets {
secretJson.parseJson.asJsObject().fields("value").convertTo[String]
}

def getAccessToken(reqResource: String): String = {
println(s"[info] token for perms: $reqResource from $AccountString")
val json = exec(s"az account get-access-token --resource $reqResource --output json")
json.parseJson.asJsObject().fields("accessToken").convertTo[String]
}

lazy val CognitiveApiKey: String = getSecret("cognitive-api-key")
lazy val OpenAIApiKey: String = getSecret("openai-api-key")
lazy val OpenAIApiKeyGpt4: String = getSecret("openai-api-key-2")
Expand All @@ -68,7 +74,5 @@ object Secrets {
lazy val ArtifactStore: String = getSecret("synapse-artifact-store")
lazy val Platform: String = getSecret("synapse-platform")
lazy val AadResource: String = getSecret("synapse-internal-aad-resource")
lazy val ServiceConnectionSecret: String = getSecret("service-connection-secret")
lazy val ServicePrincipalClientId: String = getSecret("service-principal-clientId")

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.microsoft.azure.synapse.ml.nbtest

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.Secrets.getAccessToken
import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.core.env.PackageUtils.{SparkMavenPackageList, SparkMavenRepositoryList}
import com.microsoft.azure.synapse.ml.io.http.RESTHelpers
Expand Down Expand Up @@ -118,18 +119,14 @@ object SynapseUtilities {

import SynapseJsonProtocol._

lazy val SynapseToken: String = getAccessToken(ClientId, Secrets.ServiceConnectionSecret,
"https://dev.azuresynapse.net/")
lazy val ArmToken: String = getAccessToken(ClientId, Secrets.ServiceConnectionSecret,
"https://management.azure.com/")
lazy val SynapseToken: String = getAccessToken("https://dev.azuresynapse.net/")
lazy val ArmToken: String = getAccessToken("https://management.azure.com/")

val LineSeparator: String = sys.props("line.separator").toLowerCase // Platform agnostic (\r\n:windows, \n:linux)
val Folder = s"build_${BuildInfo.version}/scripts"
val TimeoutInMillis: Int = 30 * 60 * 1000 // 30 minutes
val StorageAccount: String = "mmlsparkbuildsynapse"
val StorageContainer: String = "synapse"
val TenantId: String = "72f988bf-86f1-41af-91ab-2d7cd011db47"
val ClientId: String = Secrets.ServicePrincipalClientId
val PoolNodeSize: String = "Small"
val PoolLocation: String = "eastus2"
val WorkspaceName: String = "mmlsparkbuild"
Expand Down Expand Up @@ -176,8 +173,11 @@ object SynapseUtilities {
def uploadAndSubmitNotebook(poolName: String, notebook: File): LivyBatch = {
val dest = s"$Folder/${notebook.getName}"
exec(s"az storage fs file upload " +
s" -s ${notebook.getAbsolutePath} -p $dest -f $StorageContainer " +
" --overwrite true " +
s" -s ${notebook.getAbsolutePath}" +
s" -p $dest" +
s" -f $StorageContainer" +
s" --auth-mode login" +
s" --overwrite true" +
s" --account-name $StorageAccount")
val abfssPath = s"abfss://$StorageContainer@$StorageAccount.dfs.core.windows.net/$dest"

Expand Down Expand Up @@ -317,19 +317,4 @@ object SynapseUtilities {
safeSend(deleteRequest)
}

def getAccessToken(clientId: String, clientSecret: String, reqResource: String): String = {
val createRequest = new HttpPost(s"https://login.microsoftonline.com/$TenantId/oauth2/token")
createRequest.setHeader("Content-Type", "application/x-www-form-urlencoded")
createRequest.setEntity(
new UrlEncodedFormEntity(
List(
("grant_type", "client_credentials"),
("client_id", clientId),
("client_secret", clientSecret),
("resource", reqResource)
).map(p => new BasicNameValuePair(p._1, p._2)).asJava, "UTF-8")
)
RESTHelpers.sendAndParseJson(createRequest).asJsObject()
.fields("access_token").convertTo[String]
}
}
12 changes: 6 additions & 6 deletions pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ jobs:
- task: Docker@2
displayName: Demo Image Build
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/build-demo'
command: 'build'
buildContext: "."
Expand All @@ -210,14 +210,14 @@ jobs:
- task: Docker@2
displayName: Demo Image Push
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/build-demo'
command: 'push'
tags: $(version)
- task: Docker@2
displayName: Minimal Image Build
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/build-minimal'
command: 'build'
buildContext: "."
Expand All @@ -227,15 +227,15 @@ jobs:
- task: Docker@2
displayName: Minimal Image Push
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/build-minimal'
command: 'push'
tags: $(version)
- task: Docker@2
condition: and(eq(variables.isMaster, true), startsWith(variables['gittag'], 'v'))
displayName: Release Image Build
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/release'
command: 'build'
buildContext: "."
Expand All @@ -248,7 +248,7 @@ jobs:
condition: and(eq(variables.isMaster, true), startsWith(variables['gittag'], 'v'))
displayName: Release Image Push
inputs:
containerRegistry: 'SynapseML MCR'
containerRegistry: 'SynapseML MCR MSI'
repository: 'public/mmlspark/release'
command: 'push'
tags: |
Expand Down
2 changes: 1 addition & 1 deletion templates/update_cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ steps:
versionSpec: '8'
jdkArchitectureOption: 'x64'
jdkSourceOption: 'PreInstalled'
- bash: pip install azure-cli==2.58.0
- bash: pip install azure-cli==2.60.0
displayName: 'Upgrade Azure CLI'

0 comments on commit 7b54e89

Please sign in to comment.