Skip to content

Commit

Permalink
feat: Memoization (#827)
Browse files Browse the repository at this point in the history
- memoizable function are now generated with
`runner_memoized_function_call`
- hidden parameters are inserted, if needed

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
  • Loading branch information
3 people committed Jan 25, 2024
1 parent 1e39300 commit d0a6c71
Show file tree
Hide file tree
Showing 42 changed files with 326 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import {
isSdsBlockLambdaResult,
isSdsCall,
isSdsCallable,
isSdsClass,
isSdsDeclaration,
isSdsEnumVariant,
isSdsExpressionLambda,
isSdsExpressionStatement,
Expand All @@ -38,6 +40,7 @@ import {
isSdsMap,
isSdsMemberAccess,
isSdsModule,
isSdsParameter,
isSdsParenthesizedExpression,
isSdsPipeline,
isSdsPlaceholder,
Expand All @@ -58,6 +61,7 @@ import {
SdsAssignment,
SdsBlock,
SdsBlockLambda,
SdsCall,
SdsDeclaration,
SdsExpression,
SdsModule,
Expand All @@ -72,10 +76,12 @@ import { isInStubFile, isStubFile } from '../helpers/fileExtensions.js';
import { IdManager } from '../helpers/idManager.js';
import {
getAbstractResults,
getArguments,
getAssignees,
getImportedDeclarations,
getImports,
getModuleMembers,
getParameters,
getPlaceholderByName,
getStatements,
Parameter,
Expand All @@ -92,7 +98,7 @@ import {
import { SafeDsPartialEvaluator } from '../partialEvaluation/safe-ds-partial-evaluator.js';
import { SafeDsServices } from '../safe-ds-module.js';
import { SafeDsPurityComputer } from '../purity/safe-ds-purity-computer.js';
import { ImpurityReason } from '../purity/model.js';
import { FileRead, ImpurityReason } from '../purity/model.js';

export const CODEGEN_PREFIX = '__gen_';
const BLOCK_LAMBDA_PREFIX = `${CODEGEN_PREFIX}block_lambda_`;
Expand Down Expand Up @@ -222,10 +228,31 @@ export class SafeDsPythonGenerator {
return mapper.toString();
}

private getPythonModuleOrDefault(object: SdsModule) {
return this.builtinAnnotations.getPythonModule(object) || object.name;
}

private getPythonNameOrDefault(object: SdsDeclaration) {
return this.builtinAnnotations.getPythonName(object) || object.name;
}

private getQualifiedNamePythonCompatible(node: SdsDeclaration | undefined): string | undefined {
const segments = [];

let current: SdsDeclaration | undefined = node;
while (current) {
const currentName = isSdsModule(current)
? this.getPythonModuleOrDefault(current)
: this.getPythonNameOrDefault(current);
if (currentName) {
segments.unshift(currentName);
}
current = getContainerOfType(current.$container, isSdsDeclaration);
}

return segments.join('.');
}

private formatGeneratedFileName(baseName: string): string {
return `gen_${this.sanitizeModuleNameForPython(baseName)}`;
}
Expand Down Expand Up @@ -634,27 +661,29 @@ export class SafeDsPythonGenerator {
return traceToNode(expression)(frame.getUniqueLambdaBlockName(expression));
} else if (isSdsCall(expression)) {
const callable = this.nodeMapper.callToCallable(expression);
if (isSdsFunction(callable)) {
const pythonCall = this.builtinAnnotations.getPythonCall(callable);
if (pythonCall) {
const sortedArgs = this.sortArguments(getArguments(expression));
// Memoize constructor or function call
if (isSdsFunction(callable) || isSdsClass(callable)) {
if (isSdsFunction(callable)) {
const pythonCall = this.builtinAnnotations.getPythonCall(callable);
if (pythonCall) {
let thisParam: CompositeGeneratorNode | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = this.generateExpression(expression.receiver.receiver, frame);
}
const argumentsMap = this.getArgumentsMap(getArguments(expression), frame);
return this.generatePythonCall(expression, pythonCall, argumentsMap, frame, thisParam);
}
}
if (this.isMemoizableCall(expression)) {
let thisParam: CompositeGeneratorNode | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = this.generateExpression(expression.receiver.receiver, frame);
}
const argumentsMap = this.getArgumentsMap(expression.argumentList.arguments, frame);
return this.generatePythonCall(expression, pythonCall, argumentsMap, thisParam);
return this.generateMemoizedCall(expression, sortedArgs, frame, thisParam);
}
}

const sortedArgs = this.sortArguments(expression.argumentList.arguments);
return expandTracedToNode(expression)`${this.generateExpression(
expression.receiver,
frame,
)}(${joinTracedToNode(expression.argumentList, 'arguments')(
sortedArgs,
(arg) => this.generateArgument(arg, frame),
{ separator: ', ' },
)})`;
return this.generatePlainCall(expression, sortedArgs, frame);
} else if (isSdsExpressionLambda(expression)) {
return expandTracedToNode(expression)`lambda ${this.generateParameters(
expression.parameterList,
Expand Down Expand Up @@ -753,18 +782,30 @@ export class SafeDsPythonGenerator {
throw new Error(`Unknown expression type: ${expression.$type}`);
}

private generatePlainCall(
expression: SdsCall,
sortedArgs: SdsArgument[],
frame: GenerationInfoFrame,
): CompositeGeneratorNode {
return expandTracedToNode(expression)`${this.generateExpression(expression.receiver, frame)}(${joinTracedToNode(
expression.argumentList,
'arguments',
)(sortedArgs, (arg) => this.generateArgument(arg, frame), { separator: ', ' })})`;
}

private generatePythonCall(
expression: SdsExpression,
expression: SdsCall,
pythonCall: string,
argumentsMap: Map<string, CompositeGeneratorNode>,
frame: GenerationInfoFrame,
thisParam: CompositeGeneratorNode | undefined = undefined,
): CompositeGeneratorNode {
if (thisParam) {
argumentsMap.set('this', thisParam);
}
const splitRegex = /(\$[_a-zA-Z][_a-zA-Z0-9]*)/gu;
const splitPythonCallDefinition = pythonCall.split(splitRegex);
return joinTracedToNode(expression)(
const generatedPythonCall = joinTracedToNode(expression)(
splitPythonCallDefinition,
(part) => {
if (splitRegex.test(part)) {
Expand All @@ -775,6 +816,115 @@ export class SafeDsPythonGenerator {
},
{ separator: '' },
)!;
// Non-memoizable calls can be directly generated
if (!this.isMemoizableCall(expression)) {
return generatedPythonCall;
}
frame.addImport({ importPath: RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE });
const hiddenParameters = this.getMemoizedCallHiddenParameters(expression, frame);
const callable = this.nodeMapper.callToCallable(expression);
const memoizedArgs = getParameters(callable).map(
(parameter) => this.nodeMapper.callToParameterValue(expression, parameter)!,
);
return expandTracedToNode(
expression,
)`${RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE}.runner_memoized_function_call("${this.generateFullyQualifiedFunctionName(
expression,
)}", lambda *_ : ${generatedPythonCall}, [${joinTracedToNode(expression.argumentList, 'arguments')(
memoizedArgs,
(arg) => this.generateExpression(arg, frame),
{ separator: ', ' },
)}], [${joinToNode(hiddenParameters, (param) => param, { separator: ', ' })}])`;
}

private isMemoizableCall(expression: SdsCall): boolean {
const impurityReasons = this.purityComputer.getImpurityReasonsForExpression(expression);
// If the file is not known, the call is not memoizable
return !impurityReasons.some((reason) => !(reason instanceof FileRead) || reason.path === undefined);
}

private generateMemoizedCall(
expression: SdsCall,
sortedArgs: SdsArgument[],
frame: GenerationInfoFrame,
thisParam: CompositeGeneratorNode | undefined = undefined,
): CompositeGeneratorNode {
frame.addImport({ importPath: RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE });
const hiddenParameters = this.getMemoizedCallHiddenParameters(expression, frame);
const memoizedArgs = getParameters(this.nodeMapper.callToCallable(expression)).map(
(parameter) => this.nodeMapper.callToParameterValue(expression, parameter)!,
);
const containsOptionalArgs = sortedArgs.some((arg) =>
Parameter.isOptional(this.nodeMapper.argumentToParameter(arg)),
);
const fullyQualifiedTargetName = this.generateFullyQualifiedFunctionName(expression);
return expandTracedToNode(
expression,
)`${RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE}.runner_memoized_function_call("${fullyQualifiedTargetName}", ${
containsOptionalArgs ? 'lambda *_ : ' : ''
}${
containsOptionalArgs
? this.generatePlainCall(expression, sortedArgs, frame)
: isSdsMemberAccess(expression.receiver) && isSdsCall(expression.receiver.receiver)
? expandTracedToNode(expression.receiver)`${this.generateExpression(
expression.receiver.receiver.receiver,
frame,
)}.${this.generateExpression(expression.receiver.member!, frame)}`
: this.generateExpression(expression.receiver, frame)
}, [${thisParam ? thisParam.append(', ') : ''}${joinTracedToNode(expression.argumentList, 'arguments')(
memoizedArgs,
(arg) => this.generateExpression(arg, frame),
{
separator: ', ',
},
)}], [${joinToNode(hiddenParameters, (param) => param, { separator: ', ' })}])`;
}

private getMemoizedCallHiddenParameters(expression: SdsCall, frame: GenerationInfoFrame): CompositeGeneratorNode[] {
const impurityReasons = this.purityComputer.getImpurityReasonsForExpression(expression);
const hiddenParameters: CompositeGeneratorNode[] = [];
for (const reason of impurityReasons) {
if (reason instanceof FileRead) {
if (typeof reason.path === 'string') {
hiddenParameters.push(
expandTracedToNode(
expression,
)`${RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE}.runner_filemtime('${reason.path}')`,
);
} else if (isSdsParameter(reason.path)) {
const argument = this.nodeMapper
.parametersToArguments([reason.path], getArguments(expression))
.get(reason.path);
if (!argument) {
/* c8 ignore next 4 */
throw new Error(
'File Read impurity with dependency on parameter is present on call, but no argument has been provided.',
);
}
hiddenParameters.push(
expandTracedToNode(
argument,
)`${RUNNER_SERVER_PIPELINE_MANAGER_PACKAGE}.runner_filemtime(${this.generateArgument(
argument,
frame,
)})`,
);
}
}
}
return hiddenParameters;
}

private generateFullyQualifiedFunctionName(expression: SdsCall): string {
const callable = this.nodeMapper.callToCallable(expression);
if (isSdsDeclaration(callable)) {
const fullyQualifiedReferenceName = this.getQualifiedNamePythonCompatible(callable);
if (fullyQualifiedReferenceName) {
return fullyQualifiedReferenceName;
}
}
/* c8 ignore next */
throw new Error('Callable of provided call does not exist or is not a declaration.');
}

private getArgumentsMap(
Expand Down Expand Up @@ -803,10 +953,14 @@ export class SafeDsPythonGenerator {
.map((value) => value.arg);
}

private generateArgument(argument: SdsArgument, frame: GenerationInfoFrame): CompositeGeneratorNode {
private generateArgument(
argument: SdsArgument,
frame: GenerationInfoFrame,
generateOptionalParameterName: boolean = true,
): CompositeGeneratorNode {
const parameter = this.nodeMapper.argumentToParameter(argument);
return expandTracedToNode(argument)`${
parameter !== undefined && !Parameter.isRequired(parameter)
parameter !== undefined && !Parameter.isRequired(parameter) && generateOptionalParameterName
? expandToNode`${this.generateParameter(parameter, frame, false)}=`
: ''
}${this.generateExpression(argument.value, frame)}`;
Expand All @@ -830,13 +984,13 @@ export class SafeDsPythonGenerator {
if (declaration === importedDeclaration.declaration?.ref) {
if (importedDeclaration.alias !== undefined) {
return {
importPath: this.builtinAnnotations.getPythonModule(targetModule) || value.package,
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: importedDeclaration.declaration?.ref?.name,
alias: importedDeclaration.alias.alias,
};
} else {
return {
importPath: this.builtinAnnotations.getPythonModule(targetModule) || value.package,
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: importedDeclaration.declaration?.ref?.name,
};
}
Expand All @@ -845,7 +999,7 @@ export class SafeDsPythonGenerator {
}
if (isSdsWildcardImport(value)) {
return {
importPath: this.builtinAnnotations.getPythonModule(targetModule) || value.package,
importPath: this.getPythonModuleOrDefault(targetModule),
declarationName: declaration.name,
};
}
Expand All @@ -862,9 +1016,9 @@ export class SafeDsPythonGenerator {
const targetModule = <SdsModule>findRootNode(declaration);
if (currentModule !== targetModule && !isInStubFile(targetModule)) {
return {
importPath: `${
this.builtinAnnotations.getPythonModule(targetModule) || targetModule.name
}.${this.formatGeneratedFileName(this.getModuleFileBaseName(targetModule))}`,
importPath: `${this.getPythonModuleOrDefault(targetModule)}.${this.formatGeneratedFileName(
this.getModuleFileBaseName(targetModule),
)}`,
declarationName: this.getPythonNameOrDefault(declaration),
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Imports ----------------------------------------------------------------------

import safeds_runner.server.pipeline_manager

# Segments ---------------------------------------------------------------------

def f1(l):
Expand All @@ -11,11 +15,11 @@ def f2(l):

def test():
def __gen_block_lambda_0(a, b):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambdaResult.g", g, [], [])
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_e = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambdaResult.g", g, [], [])
__gen_block_lambda_result_e = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambdaResult.g", g, [], [])
return __gen_block_lambda_result_d, __gen_block_lambda_result_e
f2(__gen_block_lambda_1)

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@

def test():
def __gen_block_lambda_0(a, b=2):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.g", g, [], [])
return __gen_block_lambda_result_d
f1(__gen_block_lambda_0)
def __gen_block_lambda_1(a, b):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.g", g, [], [])
return __gen_block_lambda_result_d
f1(__gen_block_lambda_1)
def __gen_block_lambda_2():
pass
f2(__gen_block_lambda_2)
def __gen_block_lambda_3(a, b=2):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.g", g, [], [])
return __gen_block_lambda_result_d
g2(f3(__gen_block_lambda_3))
g2(safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.f3", f3, [__gen_block_lambda_3], []))
def __gen_block_lambda_4(a, b=2):
__gen_block_lambda_result_d = g()
__gen_block_lambda_result_d = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.g", g, [], [])
return __gen_block_lambda_result_d
c = f3(__gen_block_lambda_4)
c = safeds_runner.server.pipeline_manager.runner_memoized_function_call("tests.generator.blockLambda.f3", f3, [__gen_block_lambda_4], [])
safeds_runner.server.pipeline_manager.runner_save_placeholder('c', c)
g2(c)
Loading

0 comments on commit d0a6c71

Please sign in to comment.