Skip to content

Commit

Permalink
Handle error when pwsh does not exist language agnostic (#41910)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Feb 13, 2024
1 parent 6a9de22 commit 925ea6c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Breaking Changes

### Bugs Fixed
- `AzurePowerShellCredential` now handles the case where it falls back to legacy powershell without relying on the error message string.

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private async ValueTask<AccessToken> RequestAzurePowerShellAccessTokenAsync(bool
try
{
output = async ? await processRunner.RunAsync().ConfigureAwait(false) : processRunner.Run();
CheckForErrors(output);
CheckForErrors(output, processRunner.ExitCode);
ValidateResult(output);
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
Expand All @@ -168,7 +168,7 @@ private async ValueTask<AccessToken> RequestAzurePowerShellAccessTokenAsync(bool
}
catch (InvalidOperationException exception)
{
CheckForErrors(exception.Message);
CheckForErrors(exception.Message, processRunner.ExitCode);
if (_isChainedCredential)
{
throw new CredentialUnavailableException($"{AzurePowerShellFailedError} {exception.Message}");
Expand All @@ -181,9 +181,10 @@ private async ValueTask<AccessToken> RequestAzurePowerShellAccessTokenAsync(bool
return DeserializeOutput(output);
}

private static void CheckForErrors(string output)
private static void CheckForErrors(string output, int exitCode)
{
bool noPowerShell = (output.IndexOf("not found", StringComparison.OrdinalIgnoreCase) != -1 ||
int notFoundExitCode = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 9009 : 127;
bool noPowerShell = (exitCode == notFoundExitCode || output.IndexOf("not found", StringComparison.OrdinalIgnoreCase) != -1 ||
output.IndexOf("is not recognized", StringComparison.OrdinalIgnoreCase) != -1) &&
// If the error contains AADSTS, this should be treated as a general error to be bubbled to the user
output.IndexOf("AADSTS", StringComparison.OrdinalIgnoreCase) == -1;
Expand Down Expand Up @@ -264,7 +265,7 @@ private void GetFileNameAndArguments(string resource, string tenantId, out strin
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
fileName = Path.Combine(DefaultWorkingDirWindows, "cmd.exe");
argument = $"/d /c \"{powershellExe} \"{commandBase64}\" \"";
argument = $"/d /c \"{powershellExe} \"{commandBase64}\" \" & exit";
}
else
{
Expand Down
1 change: 1 addition & 0 deletions sdk/identity/Azure.Identity/src/ProcessRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ internal sealed class ProcessRunner : IDisposable
private readonly CancellationTokenSource _timeoutCts;
private CancellationTokenRegistration _ctRegistration;
private bool _logPII;
public int ExitCode => _process.ExitCode;

public ProcessRunner(IProcess process, TimeSpan timeout, bool logPII, CancellationToken cancellationToken)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ private static IEnumerable<object[]> FallBackErrorScenarios()
yield return new object[] { "'pwsh' is not recognized", AzurePowerShellCredential.PowerShellNotInstalledError };
yield return new object[] { "pwsh: command not found", AzurePowerShellCredential.PowerShellNotInstalledError };
yield return new object[] { "pwsh: not found", AzurePowerShellCredential.PowerShellNotInstalledError };
yield return new object[] { "foo bar", AzurePowerShellCredential.PowerShellNotInstalledError };
}

[Test]
Expand All @@ -127,7 +128,9 @@ public void AuthenticateWithAzurePowerShellCredential_FallBackErrorScenarios(str
{
// This will require two processes on Windows and one on other platforms
// Purposefully stripping out the second process to ensure any attempt to fallback is caught on non-Windows
TestProcess[] testProcesses = new TestProcess[] { new TestProcess { Error = errorMessage }, new TestProcess { Error = errorMessage } };
int exitCode = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 9009 : 127;

TestProcess[] testProcesses = new TestProcess[] { new TestProcess { Error = errorMessage, CodeOnExit = exitCode }, new TestProcess { Error = errorMessage, CodeOnExit = exitCode } };
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
testProcesses = new TestProcess[] { testProcesses[0] };

Expand Down

0 comments on commit 925ea6c

Please sign in to comment.