Skip to content

Commit

Permalink
refine factory CTOR logic (#1496)
Browse files Browse the repository at this point in the history
* refine factory CTOR logic

* add factory-gather-all-params option

* generate test cases

* refine option ctor

* fix test errors

* fix test errors

* generate test cases

* fix test errors

* special case for operationsClient

* update test cases

* update changelog and readme

* some small tweaks

---------

Co-authored-by: Joel Hendrix <jhendrix@microsoft.com>
  • Loading branch information
JiaqiZhang-Dev and jhendrixMSFT authored Feb 21, 2025
1 parent af5a817 commit 4ee31e1
Show file tree
Hide file tree
Showing 19 changed files with 422 additions and 321 deletions.
2 changes: 1 addition & 1 deletion packages/autorest.go/.scripts/regeneration.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ const keyvault = './swagger/specification/keyvault/data-plane/readme.md';
generateFromReadme("azkeyvault", keyvault, 'package-7.2', 'test/keyvault/azkeyvault', '--module=azkeyvault --module-version=0.1.0');

const consumption = './swagger/specification/consumption/resource-manager/readme.md';
generateFromReadme("armconsumption", consumption, 'package-2019-10', 'test/consumption/armconsumption', '--module=armconsumption --module-version=1.0.0 --azure-arm=true --generate-fakes=false --inject-spans=false --remove-unreferenced-types');
generateFromReadme("armconsumption", consumption, 'package-2019-10', 'test/consumption/armconsumption', '--module=armconsumption --module-version=1.0.0 --azure-arm=true --generate-fakes=false --inject-spans=false --remove-unreferenced-types --factory-gather-all-params=true');

const databoxedge = './swagger/specification/databoxedge/resource-manager/readme.md';
generateFromReadme("armdataboxedge", databoxedge, 'package-2021-02-01', 'test/databoxedge/armdataboxedge', '--module=armdataboxedge --module-version=2.0.0 --azure-arm=true --remove-unreferenced-types --inject-spans=false --fix-const-stuttering=true');
Expand Down
6 changes: 6 additions & 0 deletions packages/autorest.go/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Release History

## 4.0.0-preview.71 (unreleased)

### Other Changes

* Added switch `--factory-gather-all-params` to control the `NewClientFactory` constructor parameters. This switch allows gathering either only common parameters of clients or all parameters of clients.

## 4.0.0-preview.70 (unreleased)

### Other Changes
Expand Down
3 changes: 3 additions & 0 deletions packages/autorest.go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,7 @@ help-content:
- key: fix-const-stuttering
type: boolean
description: When true, fix stuttering for const types and their values.
- key: factory-gather-all-params
type: boolean
description: When true, the NewClientFactory constructor will gather all parameters of clients. When false, the NewClientFactory constructor will only gather common parameters of clients. The default value is false.
```
2 changes: 2 additions & 0 deletions packages/autorest.go/src/generator/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ export async function generateCode(host: AutorestExtensionHost) {
});
}

const factoryGatherAllParams = await session.getValue('factory-gather-all-params', false);
session.model.options.factoryGatherAllParams = factoryGatherAllParams;
const clientFactory = await generateClientFactory(session.model);
if (clientFactory.length > 0) {
host.writeFile({
Expand Down
530 changes: 263 additions & 267 deletions packages/autorest.go/test/network/armnetwork/zz_client_factory.go

Large diffs are not rendered by default.

48 changes: 40 additions & 8 deletions packages/codegen.go/src/clientFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,24 @@ export async function generateClientFactory(codeModel: go.CodeModel): Promise<st
// the list of packages to import
const imports = new ImportManager();

const allClientParams = helpers.getAllClientParameters(codeModel);
let clientFactoryParams: Array<go.Parameter>;
if (codeModel.options.factoryGatherAllParams) {
clientFactoryParams = helpers.getAllClientParameters(codeModel);
} else {
clientFactoryParams = helpers.getCommonClientParameters(codeModel);
}

const clientFactoryParamsMap = new Map<string, go.Parameter>();
for (const param of clientFactoryParams) {
clientFactoryParamsMap.set(param.name, param);
}

// add factory type
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore');
result += '// ClientFactory is a client factory used to create any client in this module.\n';
result += '// Don\'t use this type directly, use NewClientFactory instead.\n';
result += 'type ClientFactory struct {\n';
for (const clientParam of values(allClientParams)) {
for (const clientParam of values(clientFactoryParams)) {
result += `\t${clientParam.name} ${helpers.formatParameterTypeName(clientParam)}\n`;
}
result += '\tinternal *arm.Client\n';
Expand All @@ -36,19 +46,19 @@ export async function generateClientFactory(codeModel: go.CodeModel): Promise<st
imports.add('github.com/Azure/azure-sdk-for-go/sdk/azcore/arm');
result += '// NewClientFactory creates a new instance of ClientFactory with the specified values.\n';
result += '// The parameter values will be propagated to any client created from this factory.\n';
for (const clientParam of values(allClientParams)) {
for (const clientParam of values(clientFactoryParams)) {
result += helpers.formatCommentAsBulletItem(clientParam.name, clientParam.docs);
}
result += helpers.formatCommentAsBulletItem('credential', {summary: 'used to authorize requests. Usually a credential from azidentity.'});
result += helpers.formatCommentAsBulletItem('options', {summary: 'pass nil to accept the default values.'});

result += `func NewClientFactory(${allClientParams.map(param => { return `${param.name} ${helpers.formatParameterTypeName(param)}`; }).join(', ')}${allClientParams.length>0 ? ',' : ''} credential azcore.TokenCredential, options *arm.ClientOptions) (*ClientFactory, error) {\n`;
result += `func NewClientFactory(${clientFactoryParams.map(param => { return `${param.name} ${helpers.formatParameterTypeName(param)}`; }).join(', ')}${clientFactoryParams.length>0 ? ',' : ''} credential azcore.TokenCredential, options *arm.ClientOptions) (*ClientFactory, error) {\n`;
result += '\tinternal, err := arm.NewClient(moduleName, moduleVersion, credential, options)\n';
result += '\tif err != nil {\n';
result += '\t\treturn nil, err\n';
result += '\t}\n';
result += '\treturn &ClientFactory{\n';
for (const clientParam of values(allClientParams)) {
for (const clientParam of values(clientFactoryParams)) {
result += `\t\t${clientParam.name}: ${clientParam.name},\n`;
}
result += '\t\tinternal: internal,\n';
Expand All @@ -57,14 +67,36 @@ export async function generateClientFactory(codeModel: go.CodeModel): Promise<st

// add new sub client method for all operation groups
for (const client of codeModel.clients) {
const clientPrivateParams = new Array<go.Parameter>();
const clientCommonParams = new Array<go.Parameter>();
for (const param of client.parameters) {
if (clientFactoryParamsMap.has(param.name)) {
clientCommonParams.push(param);
} else {
clientPrivateParams.push(param);
}
}

const ctorName = `New${client.name}`;
result += `// ${ctorName} creates a new instance of ${client.name}.\n`;
result += `func (c *ClientFactory) ${ctorName}() *${client.name} {\n`;
result += `func (c *ClientFactory) ${ctorName}(`;
if (clientPrivateParams.length > 0) {
result += `${clientPrivateParams.map(param => {
return `${param.name} ${helpers.formatParameterTypeName(param)}`;
}).join(', ')}`;
}
result += `) *${client.name} {\n`;
result += `\treturn &${client.name}{\n`;

// some clients (e.g. operations client) don't utilize the client params
if (client.parameters.length > 0) {
for (const clientParam of values(allClientParams)) {
if (clientPrivateParams.length > 0) {
for (const clientParam of values(clientPrivateParams)) {
result += `\t\t${clientParam.name}: ${clientParam.name},\n`;
}
}

if (clientCommonParams.length > 0) {
for (const clientParam of values(clientCommonParams)) {
result += `\t\t${clientParam.name}: c.${clientParam.name},\n`;
}
}
Expand Down
33 changes: 26 additions & 7 deletions packages/codegen.go/src/example.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,16 @@ export async function generateExamples(codeModel: go.CodeModel): Promise<Array<E
imports.add(codeModel.options.module!.name);
}

const allClientParams = helpers.getAllClientParameters(codeModel);
let clientFactoryParams = new Array<go.Parameter>();
if (codeModel.options.factoryGatherAllParams) {
clientFactoryParams = helpers.getAllClientParameters(codeModel);
} else {
clientFactoryParams = helpers.getCommonClientParameters(codeModel);
}
const clientFactoryParamsMap = new Map<string, go.Parameter>();
for (const param of clientFactoryParams) {
clientFactoryParamsMap.set(param.name, param);
}

let exampleText = '';
for (const method of client.methods) {
Expand Down Expand Up @@ -77,20 +86,30 @@ export async function generateExamples(codeModel: go.CodeModel): Promise<Array<E
let clientRef = '';
if (azureARM) {
// since not all operation has all the client factory required parameters, we need to fake for the missing ones
const clientFactoryParams: go.ParameterExample[] = [];
for (const clientParam of allClientParams) {
const clientFactoryParamsExample: go.ParameterExample[] = [];
for (const clientParam of clientFactoryParams) {
const clientFactoryParam = clientParameters.find(p => p.parameter.name === clientParam.name);
if (clientFactoryParam) {
clientFactoryParams.push(clientFactoryParam);
clientFactoryParamsExample.push(clientFactoryParam);
} else {
clientFactoryParams.push({ parameter: clientParam, value: generateFakeExample(clientParam.type, clientParam.name) });
clientFactoryParamsExample.push({ parameter: clientParam, value: generateFakeExample(clientParam.type, clientParam.name) });
}
}
exampleText += `\tclientFactory, err := ${codeModel.packageName}.NewClientFactory(${clientFactoryParams.map(p => getExampleValue(codeModel, p.value, '\t', imports, helpers.parameterByValue(p.parameter)).slice(1)).join(', ')}${clientFactoryParams.length > 0 ? ', ' : ''}cred, nil)\n`;
exampleText += `\tclientFactory, err := ${codeModel.packageName}.NewClientFactory(${clientFactoryParamsExample.map(p => getExampleValue(codeModel, p.value, '\t', imports, helpers.parameterByValue(p.parameter)).slice(1)).join(', ')}${clientFactoryParams.length > 0 ? ', ' : ''}cred, nil)\n`;
exampleText += `\tif err != nil {\n`;
exampleText += `\t\tlog.Fatalf("failed to create client: %v", err)\n`;
exampleText += `\t}\n`;
clientRef = `clientFactory.${client.constructors[0]?.name}()`;
clientRef = `clientFactory.${client.constructors[0]?.name}(`;
const clientPrivateParameters: go.ParameterExample[] = [];
for (const clientParam of clientParameters) {
if (!clientFactoryParamsMap.has(clientParam.parameter.name)) {
clientPrivateParameters.push(clientParam);
}
}
if (clientPrivateParameters.length > 0) {
clientRef += `${clientPrivateParameters.map(p => getExampleValue(codeModel, p.value, '\t', imports, helpers.parameterByValue(p.parameter)).slice(1)).join(', ')}`;
}
clientRef += `)`;
} else {
exampleText += `\tclient, err := ${codeModel.packageName}.${client.constructors[0]?.name}(${clientParameters.map(p => getExampleValue(codeModel, p.value, '\t', imports, helpers.parameterByValue(p.parameter)).slice(1)).join(', ')}, cred, nil)\n`;
exampleText += `\tif err != nil {\n`;
Expand Down
34 changes: 34 additions & 0 deletions packages/codegen.go/src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,37 @@ export function getAllClientParameters(codeModel: go.CodeModel): Array<go.Parame
allClientParams.sort(sortParametersByRequired);
return allClientParams;
}

// returns common client parameters for all the clients
export function getCommonClientParameters(codeModel: go.CodeModel): Array<go.Parameter> {
const paramCount = new Map<string, { uses: number, param: go.Parameter }>();
let numClients = 0; // track client count since we might skip some
for (const clients of codeModel.clients) {
// special cases: some ARM clients always don't contain any parameters (OperationsClient will be depracated in the future)
if (codeModel.type === 'azure-arm' && clients.name.match(/^OperationsClient$/)) {
continue;
}

++numClients;
for (const clientParam of values(clients.parameters)) {
let entry = paramCount.get(clientParam.name);
if (!entry) {
entry = { uses: 0, param: clientParam };
paramCount.set(clientParam.name, entry);
}

++entry.uses;
}
}

// for each param, if its usage count is equal to the
// number of clients, then it's common to all clients
const commonClientParams = new Array<go.Parameter>();
for (const entry of paramCount.values()) {
if (entry.uses === numClients) {
commonClientParams.push(entry.param);
}
}

return commonClientParams.sort(sortParametersByRequired);
}
3 changes: 3 additions & 0 deletions packages/codemodel.go/src/package.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ export interface Options {
sliceElementsByval: boolean;

generateExamples: boolean;

// whether or not to gather all client parameters for the client factory.
factoryGatherAllParams: boolean;
}

export interface Module {
Expand Down
4 changes: 2 additions & 2 deletions packages/typespec-go/.scripts/tspcompile.js
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ const armdatabasewatcher = pkgRoot + 'test/tsp/DatabaseWatcher.Management';
generate('armdatabasewatcher', armdatabasewatcher, 'test/local/armdatabasewatcher', ['fix-const-stuttering=false']);

const armloadtestservice = pkgRoot + 'test/tsp/LoadTestService.Management';
generate('armloadtestservice', armloadtestservice, 'test/local/armloadtestservice');
generate('armloadtestservice', armloadtestservice, 'test/local/armloadtestservice', ['factory-gather-all-params=true']);

const armdevopsinfrastructure = pkgRoot + 'test/tsp/Microsoft.DevOpsInfrastructure';
generate('armdevopsinfrastructure', armdevopsinfrastructure, 'test/local/armdevopsinfrastructure');
Expand Down Expand Up @@ -217,7 +217,7 @@ function generate(moduleName, input, outputDir, perTestOptions) {
'head-as-boolean=true',
'fix-const-stuttering=true',
`examples-directory=${input}/examples`,
'generate-examples=true'
'generate-examples=true',
];

let allOptions = fixedOptions;
Expand Down
4 changes: 4 additions & 0 deletions packages/typespec-go/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

* Remove filtering Azure core model since some instances of template model is in `Azure.Core` namespace. Logic of filtering exception model could cover the filtering needs.

### Other Changes

* Added switch `--factory-gather-all-params` to control the `NewClientFactory` constructor parameters. This switch allows gathering either only common parameters of clients or all parameters of clients.

## 0.3.7 (2025-02-11)

### Bugs Fixed
Expand Down
2 changes: 2 additions & 0 deletions packages/typespec-go/src/lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export interface GoEmitterOptions {
'stutter'?: string;
'fix-const-stuttering'?: boolean;
'generate-examples'?: boolean;
'factory-gather-all-params'?: boolean;
}

const EmitterOptionsSchema: JSONSchemaType<GoEmitterOptions> = {
Expand All @@ -40,6 +41,7 @@ const EmitterOptionsSchema: JSONSchemaType<GoEmitterOptions> = {
'stutter': { type: 'string', nullable: true },
'fix-const-stuttering': { type: 'boolean', nullable: true },
'generate-examples': { type: 'boolean', nullable: true },
'factory-gather-all-params': { type: 'boolean', nullable: true },
},
required: [],
};
Expand Down
3 changes: 3 additions & 0 deletions packages/typespec-go/src/tcgcadapter/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ export async function tcgcToGoCodeModel(context: EmitContext<GoEmitterOptions>):
if (context.options['slice-elements-byval']) {
codeModel.options.sliceElementsByval = true;
}
if (context.options['factory-gather-all-params']) {
codeModel.options.factoryGatherAllParams = true;
}

fixStutteringTypeNames(sdkContext.sdkPackage, codeModel, context.options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func TestLocationResourcesClient_CreateOrUpdate(t *testing.T) {
resp, err := clientFactory.NewLocationResourcesClient().CreateOrUpdate(context.Background(), "eastus", "resource", resources.LocationResource{
resp, err := clientFactory.NewLocationResourcesClient(subscriptionIdExpected).CreateOrUpdate(context.Background(), "eastus", "resource", resources.LocationResource{
Properties: &resources.LocationResourceProperties{
Description: to.Ptr("valid"),
},
Expand All @@ -40,13 +40,13 @@ func TestLocationResourcesClient_CreateOrUpdate(t *testing.T) {
}

func TestLocationResourcesClient_Delete(t *testing.T) {
resp, err := clientFactory.NewLocationResourcesClient().Delete(context.Background(), "eastus", "resource", nil)
resp, err := clientFactory.NewLocationResourcesClient(subscriptionIdExpected).Delete(context.Background(), "eastus", "resource", nil)
require.NoError(t, err)
require.Zero(t, resp)
}

func TestLocationResourcesClient_Get(t *testing.T) {
resp, err := clientFactory.NewLocationResourcesClient().Get(context.Background(), "eastus", "resource", nil)
resp, err := clientFactory.NewLocationResourcesClient(subscriptionIdExpected).Get(context.Background(), "eastus", "resource", nil)
require.NoError(t, err)
require.Equal(t, resources.LocationResource{
Properties: &resources.LocationResourceProperties{
Expand All @@ -68,7 +68,7 @@ func TestLocationResourcesClient_Get(t *testing.T) {
}

func TestLocationResourcesClient_NewListByScopePager(t *testing.T) {
pager := clientFactory.NewLocationResourcesClient().NewListByLocationPager("eastus", nil)
pager := clientFactory.NewLocationResourcesClient(subscriptionIdExpected).NewListByLocationPager("eastus", nil)
pageCount := 0
for pager.More() {
page, err := pager.NextPage(context.Background())
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestLocationResourcesClient_NewListByScopePager(t *testing.T) {
}

func TestLocationResourcesClient_Update(t *testing.T) {
resp, err := clientFactory.NewLocationResourcesClient().Update(context.Background(), "eastus", "resource", resources.LocationResource{
resp, err := clientFactory.NewLocationResourcesClient(subscriptionIdExpected).Update(context.Background(), "eastus", "resource", resources.LocationResource{
Properties: &resources.LocationResourceProperties{
Description: to.Ptr("valid2"),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ var (
)

func TestNestedClient_Get(t *testing.T) {
nestedClientGetResponse, err := clientFactory.NewNestedClient().Get(
nestedClientGetResponse, err := clientFactory.NewNestedClient(subscriptionIdExpected).Get(
ctx,
"test-rg",
"top",
Expand All @@ -50,7 +50,7 @@ func TestNestedClient_Get(t *testing.T) {
}

func TestNestedClient_BeginCreateOrReplace(t *testing.T) {
nestedClientCreateOrReplaceResponsePoller, err := clientFactory.NewNestedClient().BeginCreateOrReplace(
nestedClientCreateOrReplaceResponsePoller, err := clientFactory.NewNestedClient(subscriptionIdExpected).BeginCreateOrReplace(
ctx,
"test-rg",
"top",
Expand All @@ -73,7 +73,7 @@ func TestNestedClient_BeginCreateOrReplace(t *testing.T) {
}

func TestNestedClient_BeginUpdate(t *testing.T) {
nestedClientUpdateResponsePoller, err := clientFactory.NewNestedClient().BeginUpdate(
nestedClientUpdateResponsePoller, err := clientFactory.NewNestedClient(subscriptionIdExpected).BeginUpdate(
ctx,
"test-rg",
"top",
Expand All @@ -96,15 +96,15 @@ func TestNestedClient_BeginUpdate(t *testing.T) {
}

func TestNestedClient_BeginDelete(t *testing.T) {
nestedClientDeleteResponsePoller, err := clientFactory.NewNestedClient().BeginDelete(ctx, "test-rg", "top", "nested", nil)
nestedClientDeleteResponsePoller, err := clientFactory.NewNestedClient(subscriptionIdExpected).BeginDelete(ctx, "test-rg", "top", "nested", nil)
require.NoError(t, err)
nestedClientDeleteResponse, err := nestedClientDeleteResponsePoller.Poll(ctx)
require.NoError(t, err)
require.Equal(t, http.StatusNoContent, nestedClientDeleteResponse.StatusCode)
}

func TestNestedClient_NewListByTopLevelTrackedResourcePager(t *testing.T) {
nestedClientListByTopLevelTrackedResourceResponsePager := clientFactory.NewNestedClient().NewListByTopLevelTrackedResourcePager("test-rg", "top", nil)
nestedClientListByTopLevelTrackedResourceResponsePager := clientFactory.NewNestedClient(subscriptionIdExpected).NewListByTopLevelTrackedResourcePager("test-rg", "top", nil)
require.True(t, nestedClientListByTopLevelTrackedResourceResponsePager.More())
nestedClientListByTopLevelTrackedResourceResponse, err := nestedClientListByTopLevelTrackedResourceResponsePager.NextPage(ctx)
require.NoError(t, err)
Expand Down
Loading

0 comments on commit 4ee31e1

Please sign in to comment.