From 47e9d542bdaa5ea7fc5ade86bebed55715565147 Mon Sep 17 00:00:00 2001
From: Sourav Gupta <souravgupta@microsoft.com>
Date: Fri, 21 Jul 2023 18:22:44 +0530
Subject: [PATCH] SAS creation fix when stored access policy is used

---
 sdk/storage/azqueue/CHANGELOG.md         |  2 +
 sdk/storage/azqueue/queue_client_test.go | 71 ++++++++++++++++++++++++
 sdk/storage/azqueue/sas/service.go       |  5 +-
 sdk/storage/azqueue/sas/service_test.go  | 45 +++++++++++++++
 4 files changed, 121 insertions(+), 2 deletions(-)

diff --git a/sdk/storage/azqueue/CHANGELOG.md b/sdk/storage/azqueue/CHANGELOG.md
index b9078af0b455..bca1028dd946 100644
--- a/sdk/storage/azqueue/CHANGELOG.md
+++ b/sdk/storage/azqueue/CHANGELOG.md
@@ -8,6 +8,8 @@
 
 #### Bugs Fixed
 
+* Fixed service SAS creation where expiry time or permissions can be omitted when stored access policy is used.
+
 #### Other Changes
 
 ### 1.0.0 (2023-05-09)
diff --git a/sdk/storage/azqueue/queue_client_test.go b/sdk/storage/azqueue/queue_client_test.go
index a009a67fb3ff..8a0191472a26 100644
--- a/sdk/storage/azqueue/queue_client_test.go
+++ b/sdk/storage/azqueue/queue_client_test.go
@@ -1456,3 +1456,74 @@ func (s *UnrecordedTestSuite) TestServiceSASDequeueMessage() {
 	_require.Equal(0, len(resp.Messages))
 	_require.Nil(err)
 }
+
+func (s *UnrecordedTestSuite) TestQueueSASUsingAccessPolicy() {
+	_require := require.New(s.T())
+
+	cred, err := testcommon.GetGenericCredential(testcommon.TestAccountDefault)
+	_require.NoError(err)
+
+	svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil)
+	_require.NoError(err)
+
+	testName := s.T().Name()
+	queueName := testcommon.GenerateQueueName(testName)
+	queueClient := testcommon.GetQueueClient(queueName, svcClient)
+	defer testcommon.DeleteQueue(context.Background(), _require, queueClient)
+
+	_, err = queueClient.Create(context.Background(), nil)
+	_require.NoError(err)
+
+	id := "testAccessPolicy"
+	ps := azqueue.AccessPolicyPermission{Read: true, Add: true, Update: true, Process: true}
+	signedIdentifiers := make([]*azqueue.SignedIdentifier, 0)
+	signedIdentifiers = append(signedIdentifiers, &azqueue.SignedIdentifier{
+		AccessPolicy: &azqueue.AccessPolicy{
+			Expiry:     to.Ptr(time.Now().Add(1 * time.Hour)),
+			Start:      to.Ptr(time.Now()),
+			Permission: to.Ptr(ps.String()),
+		},
+		ID: &id,
+	})
+
+	_, err = queueClient.SetAccessPolicy(context.Background(), &azqueue.SetAccessPolicyOptions{
+		QueueACL: signedIdentifiers,
+	})
+	_require.NoError(err)
+
+	gResp, err := queueClient.GetAccessPolicy(context.Background(), nil)
+	_require.NoError(err)
+	_require.Len(gResp.SignedIdentifiers, 1)
+
+	time.Sleep(30 * time.Second)
+
+	sasQueryParams, err := sas.QueueSignatureValues{
+		Protocol:   sas.ProtocolHTTPS,
+		Identifier: id,
+		QueueName:  queueName,
+	}.SignWithSharedKey(cred)
+	_require.NoError(err)
+
+	queueSAS := queueClient.URL() + "?" + sasQueryParams.Encode()
+	queueClientSAS, err := azqueue.NewQueueClientWithNoCredential(queueSAS, nil)
+	_require.NoError(err)
+
+	_, err = queueClientSAS.GetProperties(context.Background(), nil)
+	_require.NoError(err)
+
+	// enqueue 4 messages
+	for i := 0; i < 4; i++ {
+		_, err = queueClientSAS.EnqueueMessage(context.Background(), fmt.Sprintf("%v : %v", testcommon.QueueDefaultData, i), nil)
+		_require.NoError(err)
+	}
+
+	// dequeue 4 messages
+	for i := 0; i < 4; i++ {
+		resp, err := queueClientSAS.DequeueMessage(context.Background(), nil)
+		_require.NoError(err)
+		_require.Equal(1, len(resp.Messages))
+		_require.NotNil(resp.Messages[0].MessageText)
+		_require.Equal(fmt.Sprintf("%v : %v", testcommon.QueueDefaultData, i), *resp.Messages[0].MessageText)
+		_require.NotNil(resp.Messages[0].MessageID)
+	}
+}
diff --git a/sdk/storage/azqueue/sas/service.go b/sdk/storage/azqueue/sas/service.go
index 2a7477d95014..f42ebca43693 100644
--- a/sdk/storage/azqueue/sas/service.go
+++ b/sdk/storage/azqueue/sas/service.go
@@ -32,7 +32,7 @@ type QueueSignatureValues struct {
 
 // SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters.
 func (v QueueSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) {
-	if v.ExpiryTime.IsZero() || v.Permissions == "" {
+	if v.Identifier == "" && (v.ExpiryTime.IsZero() || v.Permissions == "") {
 		return QueryParameters{}, errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions")
 	}
 
@@ -75,7 +75,8 @@ func (v QueueSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCr
 		permissions: v.Permissions,
 		ipRange:     v.IPRange,
 		// Calculated SAS signature
-		signature: signature,
+		signature:  signature,
+		identifier: signedIdentifier,
 	}
 
 	return p, nil
diff --git a/sdk/storage/azqueue/sas/service_test.go b/sdk/storage/azqueue/sas/service_test.go
index be8d401f5bf7..64f5932ec2ad 100644
--- a/sdk/storage/azqueue/sas/service_test.go
+++ b/sdk/storage/azqueue/sas/service_test.go
@@ -7,8 +7,11 @@
 package sas
 
 import (
+	"errors"
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azqueue/internal/exported"
 	"github.com/stretchr/testify/require"
 	"testing"
+	"time"
 )
 
 func TestQueuePermissions_String(t *testing.T) {
@@ -79,3 +82,45 @@ func TestGetCanonicalName(t *testing.T) {
 		require.Equal(t, c.expected, getCanonicalName(c.inputAccount, c.inputQueue))
 	}
 }
+
+func TestQueueSignatureValues_SignWithSharedKey(t *testing.T) {
+	cred, err := exported.NewSharedKeyCredential("fakeaccountname", "AKIAIOSFODNN7EXAMPLE")
+	require.Nil(t, err, "error creating valid shared key credentials.")
+
+	expiryDate, err := time.Parse("2006-01-02", "2023-07-20")
+	require.Nil(t, err, "error creating valid expiry date.")
+
+	testdata := []struct {
+		object        QueueSignatureValues
+		expected      QueryParameters
+		expectedError error
+	}{
+		{
+			object:        QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "r", ExpiryTime: expiryDate},
+			expected:      QueryParameters{version: Version, permissions: "r", expiryTime: expiryDate},
+			expectedError: nil,
+		},
+		{
+			object:        QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "", ExpiryTime: expiryDate},
+			expected:      QueryParameters{},
+			expectedError: errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions"),
+		},
+		{
+			object:        QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "r", ExpiryTime: *new(time.Time)},
+			expected:      QueryParameters{},
+			expectedError: errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions"),
+		},
+		{
+			object:        QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "", ExpiryTime: *new(time.Time), Identifier: "fakepolicyname"},
+			expected:      QueryParameters{version: Version, identifier: "fakepolicyname"},
+			expectedError: nil,
+		},
+	}
+	for _, c := range testdata {
+		act, err := c.object.SignWithSharedKey(cred)
+		// ignore signature value
+		act.signature = ""
+		require.Equal(t, c.expected, act)
+		require.Equal(t, c.expectedError, err)
+	}
+}