Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several updates #161

Merged
merged 21 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -819,12 +819,14 @@ public void testDataset2()
test("tf2_test_dataset2.py", "add", 2, 2, 2, 3);
}

/** Not a legal case. */
@Test
public void testDataset3()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset3.py";
  test(script, "add", 2, 2, 2, 3);
}

/** This is not a legal case. */
@Test
public void testDataset4()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Expand Down Expand Up @@ -922,6 +924,37 @@ public void testDataset15()
test("tf2_test_dataset14.py", "g", 1, 1, 2);
}

/** Tests a dataset consumed through an explicit iterator. */
@Test
public void testDataset16()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset16.py";
  test(script, "add", 2, 2, 2, 3);
}

/** Tests a dataset consumed through an explicit iterator. */
@Test
public void testDataset17()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset17.py";
  test(script, "add", 2, 2, 2, 3);
  test(script, "f", 1, 1, 2);
}

/** Tests a dataset consumed through an explicit iterator. */
@Test
public void testDataset18()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset18.py";
  test(script, "add", 2, 2, 2, 3);
  test(script, "f", 1, 1, 2);
  test(script, "g", 0, 2);
}

/** Tests a dataset consumed through an explicit iterator. */
@Test
public void testDataset19()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset19.py";
  test(script, "distributed_train_step", 1, 1, 2);
}

@Test
public void testDataset20()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Expand Down
30 changes: 30 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
<putfield class="LRoot" field="layers" fieldType="LRoot" ref="keras" value="layers" />
<new def="models" class="Lobject" />
<putfield class="LRoot" field="models" fieldType="LRoot" ref="keras" value="models" />
<new def="preprocessing" class="Lobject" />
<putfield class="LRoot" field="preprocessing" fieldType="LRoot" ref="keras" value="preprocessing" />
<new def="image" class="Lobject" />
<putfield class="LRoot" field="image" fieldType="LRoot" ref="preprocessing" value="image" />
<new def="app" class="Lobject" />
<putfield class="LRoot" field="app" fieldType="LRoot" ref="x" value="app" />
<new def="run" class="Ltensorflow/app/run" />
Expand Down Expand Up @@ -165,6 +169,8 @@
<new def="Model" class="Ltensorflow/keras/models/Model" />
<putfield class="LRoot" field="Model" fieldType="LRoot" ref="keras" value="Model" />
<putfield class="LRoot" field="Model" fieldType="LRoot" ref="models" value="Model" />
<new def="ImageDataGenerator" class="Ltensorflow/keras/preprocessing/image/ImageDataGenerator" />
<putfield class="LRoot" field="ImageDataGenerator" fieldType="LRoot" ref="image" value="ImageDataGenerator" />
<new def="Variable" class="Ltensorflow/functions/Variable" />
<putfield class="LRoot" field="Variable" fieldType="LRoot" ref="x" value="Variable" />
<putfield class="LRoot" field="Variable" fieldType="LRoot" ref="variables" value="Variable" />
Expand Down Expand Up @@ -748,6 +754,30 @@
</method>
</class>
</package>
<package name="tensorflow/keras/preprocessing/image">
  <!-- Synthetic summaries for tf.keras.preprocessing.image so the analysis can model
       data produced by ImageDataGenerator.flow_from_directory as a dataset source. -->
  <class name="ImageDataGenerator" allocatable="true">
    <!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator -->
    <method name="read_dataset" descriptor="()LRoot;">
      <!-- Attach a `flow_from_directory` member to the receiver (arg0) and return the receiver. -->
      <new def="flow_from_directory" class="Ltensorflow/keras/preprocessing/image/flow_from_directory" />
      <putfield class="LRoot" field="flow_from_directory" fieldType="LRoot" ref="arg0" value="flow_from_directory" />
      <return value="arg0" />
    </method>
    <method name="do" descriptor="()LRoot;" numArgs="24"
      paramNames="self featurewise_center samplewise_center featurewise_std_normalization samplewise_std_normalization zca_whitening zca_epsilon rotation_range width_shift_range height_shift_range brightness_range shear_range zoom_range channel_shift_range fill_mode cval horizontal_flip vertical_flip rescale preprocessing_function data_format validation_split interpolation_order dtype">
      <!-- NOTE: Workaround for https://github.com/wala/ML/issues/127. This ctor doesn't really return a dataset but rather the instance methods do. It shouldn't be a problem since you can't iterate over an `ImageDataGenerator`. -->
      <call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
      <return value="x" />
    </method>
  </class>
  <class name="flow_from_directory" allocatable="true">
    <!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator#flow_from_directory -->
    <method name="do" descriptor="()LRoot;" numArgs="16" paramNames="self directory target_size color_mode classes class_mode batch_size shuffle seed save_to_dir save_prefix save_format follow_links subset interpolation keep_aspect_ratio">
      <!-- Modeled as allocating a tensorflow/data/Dataset and returning its read_dataset result,
           so iterating the returned value is treated like iterating a dataset. -->
      <new def="x" class="Ltensorflow/data/Dataset" />
      <call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
      <return value="xx" />
    </method>
  </class>
</package>
<package name="tensorflow/keras/models">
<class name="Model" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeA
PythonTypes.pythonLoader, TypeName.string2TypeName("Lwala/builtin/enumerate")),
AstMethodReference.fnSelector);

/** {@link MethodReference} for the Python builtin {@code next()} (Lwala/builtin/next). */
private static final MethodReference NEXT =
    MethodReference.findOrCreate(
        TypeReference.findOrCreate(
            PythonTypes.pythonLoader, TypeName.string2TypeName("Lwala/builtin/next")),
        AstMethodReference.fnSelector);

private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

private static Set<PointsToSetVariable> getDataflowSources(
Expand Down Expand Up @@ -126,6 +132,37 @@ private static Set<PointsToSetVariable> getDataflowSources(
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source from tensor generator: " + src + ".");
} else if (ni.getNumberOfUses() > 1) {
// Get the invoked function from the PA.
int target = ni.getUse(0);
PointerKey targetKey =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, target);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) {
if (ik instanceof ConcreteTypeKey) {
ConcreteTypeKey ctk = (ConcreteTypeKey) ik;
IClass type = ctk.getType();
TypeReference reference = type.getReference();

if (reference.equals(NEXT.getDeclaringClass())) {
// it's a call to `next()`. Look up the call to `iter()`.
int iterator = ni.getUse(1);
SSAInstruction iteratorDef = du.getDef(iterator);

// Let's see if the iterator is over a tensor dataset.
if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) {
// Get the argument.
int iterArg = iteratorDef.getUse(1);
processInstructionInterprocedurally(
iteratorDef, iterArg, localPointerKeyNode, src, sources, pointerAnalysis);
} else
// Use the original instruction. NOTE: We can only do this because `iter()` is
// currently just passing-through its argument.
processInstructionInterprocedurally(
ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis);
}
}
}
}
} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
Expand All @@ -152,8 +189,7 @@ private static Set<PointsToSetVariable> getDataflowSources(
src,
sources,
callGraph,
pointerAnalysis,
newHashSet());
pointerAnalysis);
}
} else if (inst instanceof PythonPropertyRead) {
// We are potentially pulling a tensor out of a non-scalar tensor iterable.
Expand All @@ -172,21 +208,42 @@ private static Set<PointsToSetVariable> getDataflowSources(
|| def instanceof PythonPropertyRead
|| def instanceof PythonInvokeInstruction) {
processInstruction(
def,
du,
localPointerKeyNode,
src,
sources,
callGraph,
pointerAnalysis,
newHashSet());
def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis);
}
}
}
}
return sources;
}

/**
 * Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable}
 * is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources.
 *
 * @param instruction The {@link SSAInstruction} to process.
 * @param du The {@link DefUse} corresponding to the given {@link SSAInstruction}.
 * @param node The {@link CGNode} containing the given {@link SSAInstruction}.
 * @param src The {@link PointsToSetVariable} under question as to whether it should be considered
 *     a tensor dataflow source.
 * @param sources The {@link Set} of tensor dataflow sources.
 * @param callGraph The {@link CallGraph} containing the given {@link SSAInstruction}.
 * @param pointerAnalysis The {@link PointerAnalysis} corresponding to the given {@link
 *     CallGraph}.
 * @return True iff the given {@link PointsToSetVariable} was added to the given {@link Set} of
 *     {@link PointsToSetVariable} dataflow sources.
 */
private static boolean processInstruction(
    SSAInstruction instruction,
    DefUse du,
    CGNode node,
    PointsToSetVariable src,
    Set<PointsToSetVariable> sources,
    CallGraph callGraph,
    PointerAnalysis<InstanceKey> pointerAnalysis) {
  // Convenience overload: delegates with a fresh, empty set — presumably the extra
  // parameter tracks already-visited instructions; confirm against the 8-arg overload.
  return processInstruction(
      instruction, du, node, src, sources, callGraph, pointerAnalysis, newHashSet());
}

/**
* Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable}
* is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources.
Expand Down Expand Up @@ -270,7 +327,7 @@ private static boolean processInstructionInterprocedurally(
PointerAnalysis<InstanceKey> pointerAnalysis) {
logger.info(
() ->
"Using interprocedural analysis to find potential tensor iterable definition for use: "
"Using interprocedural analysis to find potential tensor definition for use: "
+ use
+ " of instruction: "
+ instruction
Expand Down Expand Up @@ -300,7 +357,7 @@ private static boolean processInstructionInterprocedurally(
* Returns true iff the given {@link PointsToSetVariable} refers to a tensor dataset element of
* the dataset defined by the given value number in the given {@link CGNode}.
*
* @param src The {@link PointsToSetVariable} to consider.
* @param variable The {@link PointsToSetVariable} to consider.
* @param val The value in the given {@link CGNode} representing the tensor dataset.
* @param node The {@link CGNode} containing the given {@link PointsToSetVariable} and value.
* @param pointerAnalysis The {@link PointerAnalysis} that includes points-to information for the
Expand All @@ -309,7 +366,10 @@ private static boolean processInstructionInterprocedurally(
* val in node.
*/
private static boolean isDatasetTensorElement(
PointsToSetVariable src, int val, CGNode node, PointerAnalysis<InstanceKey> pointerAnalysis) {
PointsToSetVariable variable,
int val,
CGNode node,
PointerAnalysis<InstanceKey> pointerAnalysis) {
SSAInstruction def = node.getDU().getDef(val);

if (def instanceof PythonInvokeInstruction) {
Expand All @@ -335,7 +395,8 @@ private static boolean isDatasetTensorElement(

PythonPropertyRead srcDef =
(PythonPropertyRead)
node.getDU().getDef(((LocalPointerKey) src.getPointerKey()).getValueNumber());
node.getDU()
.getDef(((LocalPointerKey) variable.getPointerKey()).getValueNumber());

// What does the member reference point to?
PointerKey memberRefPointerKey =
Expand Down
15 changes: 15 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Test fixture: a tensor flows out of a tf.data.Dataset through an explicit
# iterator (`iter(...)` then `next(...)`) into both parameters of `add`.
# NOTE(review): the analysis tests assert on this exact iter/next shape, so the
# structure of this script must not be changed.
import tensorflow as tf


def add(a, b):
    # Receives a dataset element in both arguments from the loop below.
    return a + b


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

my_iter = iter(dataset)
length = len(dataset)

for _ in range(length):
    element = next(my_iter)
    add(element, element)
19 changes: 19 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset17.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Test fixture: a dataset element obtained via `iter(...)`/`next(...)` is passed
# through `f` into `add`, exercising one level of interprocedural dataflow.
# NOTE(review): the analysis tests assert on this exact iter/next shape, so the
# structure of this script must not be changed.
import tensorflow as tf


def add(a, b):
    # Receives the dataset element in both arguments via `f`.
    return a + b


def f(a):
    # Forwards its (dataset-element) argument to both parameters of `add`.
    return add(a, a)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

my_iter = iter(dataset)
length = len(dataset)

for _ in range(length):
    element = next(my_iter)
    f(element)
23 changes: 23 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Test fixture: the dataset ITERATOR itself is passed to `g`, which calls
# `next(...)` on it and forwards the element through `f` into `add` —
# exercising interprocedural dataflow where `next` occurs in a callee.
# NOTE(review): the analysis tests assert on this exact iter/next shape, so the
# structure of this script must not be changed.
import tensorflow as tf


def add(a, b):
    # Receives the dataset element in both arguments via `f`.
    return a + b


def f(a):
    # Forwards its (dataset-element) argument to both parameters of `add`.
    return add(a, a)


def g(a):
    # `a` is the dataset iterator; pulling an element happens here, not at top level.
    element = next(a)
    return f(element)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

my_iter = iter(dataset)
length = len(dataset)

for _ in range(length):
    g(my_iter)
32 changes: 32 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset19.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# From https://github.com/YunYang1994/TensorFlow2.0-Examples/blob/299fd6689f242d0f647a96b8844e86325e9fcb46/7-Utils/multi_gpu_train.py.
# Test fixture: data produced by ImageDataGenerator.flow_from_directory is
# consumed via `iter(...)`/`next(...)` and passed to `distributed_train_step`.
# NOTE(review): the analysis tests assert on this exact shape; do not restructure.

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


@tf.function
def distributed_train_step(dataset_inputs):
    # Intentionally empty — the test only checks dataflow into the parameter.
    pass


EPOCHS = 40
IMG_SIZE = 112  # Input Image Size
BATCH_SIZE = 512  # Total 4 GPU, 128 batch per GPU

train_datagen = ImageDataGenerator(
    rescale=1.0 / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=False
)

train_generator = train_datagen.flow_from_directory(
    "./mnist/train",
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode="categorical",
)

for epoch in range(1, EPOCHS + 1):
    batchs_per_epoch = len(train_generator)
    train_dataset = iter(train_generator)

    for _ in range(batchs_per_epoch):
        batch_loss = distributed_train_step(next(train_dataset))
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey;
import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions.BuiltinFunction;
import com.ibm.wala.cast.python.ir.PythonLanguage;
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
Expand Down Expand Up @@ -40,7 +39,6 @@
import com.ibm.wala.ssa.SymbolTable;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.collections.HashMapFactory;
import com.ibm.wala.util.collections.Pair;
import com.ibm.wala.util.intset.IntIterator;
import com.ibm.wala.util.intset.IntSet;
Expand All @@ -49,7 +47,6 @@
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.logging.Logger;

public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder {
Expand Down Expand Up @@ -98,9 +95,6 @@ protected boolean sameMethod(CGNode opNode, String definingMethod) {
private static final Collection<TypeReference> types =
Arrays.asList(PythonTypes.string, TypeReference.Int);

private final Map<Pair<String, TypeReference>, BuiltinFunction> primitives =
HashMapFactory.make();

public static class PythonConstraintVisitor extends AstConstraintVisitor
implements PythonInstructionVisitor {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ private static TypeReference builtinFunction(String name) {
builtinFunctions.put("__delete__", Either.forRight(2));
// https://docs.python.org/3/library/functions.html#print
builtinFunctions.put("print", Either.forLeft(TypeReference.Void));
// https://docs.python.org/3/library/functions.html#iter
builtinFunctions.put("iter", Either.forRight(2));
// https://docs.python.org/3/library/functions.html#next
builtinFunctions.put("next", Either.forLeft(PythonTypes.object));
}

public static Set<String> builtins() {
Expand Down
Loading
Loading