From 404a34f5034a2647c512ae96a021a153f7abccd6 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Tue, 12 Mar 2024 14:44:44 -0400 Subject: [PATCH] Several updates (#161) - Fail the build earlier. - Remove "clean" from the build. Makes it faster. - Use better variable names. - Make an internal method. - Add javadoc. - Remove unused variable. - Add `iter()` and `next()` built-ins. - Handle dataset iteration. - Add summary for `tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory()`. --- .../python/ml/test/TestTensorflow2Model.java | 33 +++++++ .../data/tensorflow.xml | 30 +++++++ .../ml/client/PythonTensorAnalysisEngine.java | 89 ++++++++++++++++--- .../data/tf2_test_dataset16.py | 15 ++++ .../data/tf2_test_dataset17.py | 19 ++++ .../data/tf2_test_dataset18.py | 23 +++++ .../data/tf2_test_dataset19.py | 32 +++++++ .../PythonSSAPropagationCallGraphBuilder.java | 6 -- .../ipa/summaries/BuiltinFunctions.java | 4 + .../wala/cast/python/loader/PythonLoader.java | 2 + .../cast/python/parser/AbstractParser.java | 2 - .../wala/cast/python/types/PythonTypes.java | 4 + 12 files changed, 237 insertions(+), 22 deletions(-) create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset16.py create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset17.py create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset18.py create mode 100644 com.ibm.wala.cast.python.test/data/tf2_test_dataset19.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 99df3c932..7669febb1 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -819,12 +819,14 @@ public void testDataset2() test("tf2_test_dataset2.py", "add", 2, 2, 2, 3); } + /** This is not a legal case. */ @Test public void testDataset3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_dataset3.py", "add", 2, 2, 2, 3); } + /** This is not a legal case. */ @Test public void testDataset4() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -922,6 +924,37 @@ public void testDataset15() test("tf2_test_dataset14.py", "g", 1, 1, 2); } + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset16() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset16.py", "add", 2, 2, 2, 3); + } + + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset17() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset17.py", "add", 2, 2, 2, 3); + test("tf2_test_dataset17.py", "f", 1, 1, 2); + } + + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset18() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset18.py", "add", 2, 2, 2, 3); + test("tf2_test_dataset18.py", "f", 1, 1, 2); + test("tf2_test_dataset18.py", "g", 0, 2); + } + + /** Test a dataset that uses an iterator. */ + @Test + public void testDataset19() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_dataset19.py", "distributed_train_step", 1, 1, 2); + } + @Test public void testDataset20() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index 56486c202..0148ac7d8 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -47,6 +47,10 @@ + + + + @@ -165,6 +169,8 @@ + + @@ -748,6 +754,30 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 048a395b8..393b6c88b 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -97,6 +97,12 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine errorLog = HashMapFactory.make(); private static Set getDataflowSources( @@ -126,6 +132,37 @@ private static Set getDataflowSources( && ni.getException() != vn) { sources.add(src); logger.info("Added dataflow source from tensor generator: " + src + "."); + } else if (ni.getNumberOfUses() > 1) { + // Get the invoked function from the PA. + int target = ni.getUse(0); + PointerKey targetKey = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, target); + + for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) { + if (ik instanceof ConcreteTypeKey) { + ConcreteTypeKey ctk = (ConcreteTypeKey) ik; + IClass type = ctk.getType(); + TypeReference reference = type.getReference(); + + if (reference.equals(NEXT.getDeclaringClass())) { + // it's a call to `next()`. Look up the call to `iter()`. + int iterator = ni.getUse(1); + SSAInstruction iteratorDef = du.getDef(iterator); + + // Let's see if the iterator is over a tensor dataset. + if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) { + // Get the argument. + int iterArg = iteratorDef.getUse(1); + processInstructionInterprocedurally( + iteratorDef, iterArg, localPointerKeyNode, src, sources, pointerAnalysis); + } else + // Use the original instruction. NOTE: We can only do this because `iter()` is + // currently just passing-through its argument. + processInstructionInterprocedurally( + ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis); + } + } + } } } else if (inst instanceof EachElementGetInstruction) { // We are potentially pulling a tensor out of a tensor iterable. @@ -152,8 +189,7 @@ private static Set getDataflowSources( src, sources, callGraph, - pointerAnalysis, - newHashSet()); + pointerAnalysis); } } else if (inst instanceof PythonPropertyRead) { // We are potentially pulling a tensor out of a non-scalar tensor iterable. @@ -172,14 +208,7 @@ private static Set getDataflowSources( || def instanceof PythonPropertyRead || def instanceof PythonInvokeInstruction) { processInstruction( - def, - du, - localPointerKeyNode, - src, - sources, - callGraph, - pointerAnalysis, - newHashSet()); + def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis); } } } @@ -187,6 +216,34 @@ private static Set getDataflowSources( return sources; } + /** + * Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable} + * is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources. + * + * @param instruction The {@link SSAInstruction} to process. + * @param du The {@link DefUse} corresponding to the given {@link SSAInstruction}. + * @param node The {@link CGNode} containing the given {@link SSAInstruction}. + * @param src The {@link PointsToSetVariable} under question as to whether it should be considered + * a tensor dataflow source. + * @param sources The {@link Set} of tensor dataflow sources. + * @param callGraph The {@link CallGraph} containing the given {@link SSAInstruction}. + * @param pointerAnalysis The {@link PointerAnalysis} corresponding to the given {@link + * CallGraph}. + * @return True iff the given {@link PointsToSetVariable} was added to the given {@link Set} of + * {@link PointsToSetVariable} dataflow sources. + */ + private static boolean processInstruction( + SSAInstruction instruction, + DefUse du, + CGNode node, + PointsToSetVariable src, + Set sources, + CallGraph callGraph, + PointerAnalysis pointerAnalysis) { + return processInstruction( + instruction, du, node, src, sources, callGraph, pointerAnalysis, newHashSet()); + } + /** * Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable} * is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources. @@ -270,7 +327,7 @@ private static boolean processInstructionInterprocedurally( PointerAnalysis pointerAnalysis) { logger.info( () -> - "Using interprocedural analysis to find potential tensor iterable definition for use: " + "Using interprocedural analysis to find potential tensor definition for use: " + use + " of instruction: " + instruction @@ -300,7 +357,7 @@ private static boolean processInstructionInterprocedurally( * Returns true iff the given {@link PointsToSetVariable} refers to a tensor dataset element of * the dataset defined by the given value number in the given {@link CGNode}. * - * @param src The {@link PointsToSetVariable} to consider. + * @param variable The {@link PointsToSetVariable} to consider. * @param val The value in the given {@link CGNode} representing the tensor dataset. * @param node The {@link CGNode} containing the given {@link PointsToSetVariable} and value. * @param pointerAnalysis The {@link PointerAnalysis} that includes points-to information for the @@ -309,7 +366,10 @@ private static boolean processInstructionInterprocedurally( * val in node. */ private static boolean isDatasetTensorElement( - PointsToSetVariable src, int val, CGNode node, PointerAnalysis pointerAnalysis) { + PointsToSetVariable variable, + int val, + CGNode node, + PointerAnalysis pointerAnalysis) { SSAInstruction def = node.getDU().getDef(val); if (def instanceof PythonInvokeInstruction) { @@ -335,7 +395,8 @@ private static boolean isDatasetTensorElement( PythonPropertyRead srcDef = (PythonPropertyRead) - node.getDU().getDef(((LocalPointerKey) src.getPointerKey()).getValueNumber()); + node.getDU() + .getDef(((LocalPointerKey) variable.getPointerKey()).getValueNumber()); // What does the member reference point to? PointerKey memberRefPointerKey = diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset16.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset16.py new file mode 100644 index 000000000..f5bd602ef --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset16.py @@ -0,0 +1,15 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + +my_iter = iter(dataset) +length = len(dataset) + +for _ in range(length): + element = next(my_iter) + add(element, element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset17.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset17.py new file mode 100644 index 000000000..151553e2e --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset17.py @@ -0,0 +1,19 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +def f(a): + return add(a, a) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + +my_iter = iter(dataset) +length = len(dataset) + +for _ in range(length): + element = next(my_iter) + f(element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset18.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset18.py new file mode 100644 index 000000000..9ad0d3758 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset18.py @@ -0,0 +1,23 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +def f(a): + return add(a, a) + + +def g(a): + element = next(a) + return f(element) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + +my_iter = iter(dataset) +length = len(dataset) + +for _ in range(length): + g(my_iter) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset19.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset19.py new file mode 100644 index 000000000..db492a6e6 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset19.py @@ -0,0 +1,32 @@ +# From https://github.com/YunYang1994/TensorFlow2.0-Examples/blob/299fd6689f242d0f647a96b8844e86325e9fcb46/7-Utils/multi_gpu_train.py. + +import tensorflow as tf +from tensorflow.keras.preprocessing.image import ImageDataGenerator + + +@tf.function +def distributed_train_step(dataset_inputs): + pass + + +EPOCHS = 40 +IMG_SIZE = 112 # Input Image Size +BATCH_SIZE = 512 # Total 4 GPU, 128 batch per GPU + +train_datagen = ImageDataGenerator( + rescale=1.0 / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=False +) + +train_generator = train_datagen.flow_from_directory( + "./mnist/train", + target_size=(IMG_SIZE, IMG_SIZE), + batch_size=BATCH_SIZE, + class_mode="categorical", +) + +for epoch in range(1, EPOCHS + 1): + batchs_per_epoch = len(train_generator) + train_dataset = iter(train_generator) + + for _ in range(batchs_per_epoch): + batch_loss = distributed_train_step(next(train_dataset)) diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java index 827aa513e..59eaa0743 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java @@ -12,7 +12,6 @@ import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder; import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey; -import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions.BuiltinFunction; import com.ibm.wala.cast.python.ir.PythonLanguage; import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor; import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction; @@ -40,7 +39,6 @@ import com.ibm.wala.ssa.SymbolTable; import com.ibm.wala.types.FieldReference; import com.ibm.wala.types.TypeReference; -import com.ibm.wala.util.collections.HashMapFactory; import com.ibm.wala.util.collections.Pair; import com.ibm.wala.util.intset.IntIterator; import com.ibm.wala.util.intset.IntSet; @@ -49,7 +47,6 @@ import com.ibm.wala.util.intset.OrdinalSet; import java.util.Arrays; import java.util.Collection; -import java.util.Map; import java.util.logging.Logger; public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder { @@ -98,9 +95,6 @@ protected boolean sameMethod(CGNode opNode, String definingMethod) { private static final Collection types = Arrays.asList(PythonTypes.string, TypeReference.Int); - private final Map, BuiltinFunction> primitives = - HashMapFactory.make(); - public static class PythonConstraintVisitor extends AstConstraintVisitor implements PythonInstructionVisitor { diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/summaries/BuiltinFunctions.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/summaries/BuiltinFunctions.java index 329248320..73ef701b8 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/summaries/BuiltinFunctions.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/summaries/BuiltinFunctions.java @@ -288,6 +288,10 @@ private static TypeReference builtinFunction(String name) { builtinFunctions.put("__delete__", Either.forRight(2)); // https://docs.python.org/3/library/functions.html#print builtinFunctions.put("print", Either.forLeft(TypeReference.Void)); + // https://docs.python.org/3/library/functions.html#iter + builtinFunctions.put("iter", Either.forRight(2)); + // https://docs.python.org/3/library/functions.html#next + builtinFunctions.put("next", Either.forLeft(PythonTypes.object)); } public static Set builtins() { diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/loader/PythonLoader.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/loader/PythonLoader.java index dbac73250..c82645ee1 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/loader/PythonLoader.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/loader/PythonLoader.java @@ -204,6 +204,8 @@ protected TranslatorToIR initTranslator(Set> topLe new CoreClass(PythonTypes.trampoline.getName(), PythonTypes.CodeBody.getName(), this, null); final CoreClass superfun = new CoreClass(PythonTypes.superfun.getName(), PythonTypes.CodeBody.getName(), this, null); + final CoreClass iterator = + new CoreClass(PythonTypes.iterator.getName(), PythonTypes.object.getName(), this, null); public PythonLoader(IClassHierarchy cha, IClassLoader parent) { super(cha, parent); diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/parser/AbstractParser.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/parser/AbstractParser.java index c55008c37..1c17a2f1f 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/parser/AbstractParser.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/parser/AbstractParser.java @@ -64,12 +64,10 @@ public interface PythonGlobalsEntity { "id", "input", "isinstance", - "iter", "locals", "map", "max", "min", - "next", "object", "open", "ord", diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/types/PythonTypes.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/types/PythonTypes.java index 50c324635..6a15512aa 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/types/PythonTypes.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/types/PythonTypes.java @@ -72,4 +72,8 @@ public class PythonTypes extends AstTypeReference { public static final TypeReference superfun = TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Lsuperfun")); + + /** https://docs.python.org/3/library/stdtypes.html#typeiter. */ + public static final TypeReference iterator = + TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Literator")); }