Skip to content

Commit

Permalink
Remove extra generics from op generation (#193)
Browse files Browse the repository at this point in the history
* Successfully remove extra type params, but it broke javadoc generation

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Generate covariant types

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Do generation

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Update help text.

Signed-off-by: Ryan Nett <rnett@calpoly.edu>

* Fixes

Signed-off-by: Ryan Nett <rnett@calpoly.edu>
  • Loading branch information
rnett authored Jan 26, 2021
1 parent b54526c commit 5e4b214
Show file tree
Hide file tree
Showing 299 changed files with 797 additions and 978 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const char kUsageHeader[] =
"provided list of libraries. A wrapper exposes an intuitive and\n"
"strongly-typed interface for building its underlying operation and linking "
"it into a graph.\n\n"
"The first argument is the location of the tensorflow binary built for TF-"
"Java.\nFor example, `bazel-out/k8-opt/bin/external/org_tensorflow/tensorfl"
"ow/libtensorflow_cc.so`.\n\n"
"Operation wrappers are generated under the path specified by the "
"'--output_dir' argument. This path can be absolute or relative to the\n"
"current working directory and will be created if it does not exist.\n\n"
Expand All @@ -45,7 +48,9 @@ const char kUsageHeader[] =
"Finally, the `--api_dirs` argument takes a list of comma-separated "
"directories of API definitions can be provided to override default\n"
"values found in the ops definitions. Directories are ordered by priority "
"(the last having precedence over the first).\n\n";
"(the last having precedence over the first).\nFor example, `bazel-tensorf"
"low-core-api/external/org_tensorflow/tensorflow/core/api_def/base_api,src"
"/bazel/api_def`\n\n";

} // namespace java
} // namespace tensorflow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,63 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
for (const auto& endpoint_def : api_def.endpoint()) {
op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
}
op.RemoveExtraGenerics();
return op;
}

void OpSpec::RemoveExtraGenerics() {
std::map<string, int> generics;

for (const ArgumentSpec& output : this->outputs()) {
if (output.type().kind() == Type::GENERIC && !output.type().wildcard()) {
if (generics.find(output.type().name()) == generics.end()) {
generics[output.type().name()] = 1;
} else {
generics[output.type().name()] = generics.find(output.type().name())->second + 1;
}
}
}

for (const ArgumentSpec& input : this->inputs()) {
if (input.type().kind() == Type::GENERIC && !input.type().wildcard()) {
if (generics.find(input.type().name()) == generics.end()) {
generics[input.type().name()] = 1;
} else {
generics[input.type().name()] = generics.find(input.type().name())->second + 1;
}
}
}

for (ArgumentSpec& output : this->outputs_) {
if (output.type().kind() == Type::GENERIC && !output.type().wildcard()) {
if (generics[output.type().name()] <= 1) {
output.toUpperBound();
}
}
}

for (ArgumentSpec& input : this->inputs_) {
if (generics[input.type().name()] <= 1) {
input.toUpperBound();
}
}
}

void ArgumentSpec::toUpperBound() {
if(this->type().kind() == Type::GENERIC && this->var().type().name() == "Operand" &&
this->type().supertypes().size() == 1){
Type newType = Type::Wildcard().add_supertype(this->type().supertypes().front());
Type varType = Type::Interface("Operand", "org.tensorflow").add_parameter(newType);

if(this->var().variadic()){
this->var_ = Variable::Varargs(this->var().name(), varType);
} else {
this->var_ = Variable::Create(this->var().name(), varType);
}

this->type_ = newType;
}
}

} // namespace java
} // namespace tensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,14 @@ class ArgumentSpec {
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
const Type& type() const { return type_; }
void toUpperBound();
const string& description() const { return description_; }
bool iterable() const { return iterable_; }

private:
const string op_def_name_;
const Variable var_;
const Type type_;
Variable var_;
Type type_;
const string description_;
const bool iterable_;
};
Expand Down Expand Up @@ -164,6 +165,8 @@ class OpSpec {
hidden_(hidden),
deprecation_explanation_(deprecation_explanation) {}

void RemoveExtraGenerics();

const string graph_op_name_;
const bool hidden_;
const string deprecation_explanation_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public final class DtypesOps {
* @param options carries optional attributes values
* @return a new instance of AsString
*/
public <T extends TType> AsString asString(Operand<T> input, AsString.Options... options) {
public AsString asString(Operand<? extends TType> input, AsString.Options... options) {
return AsString.create(scope, input, options);
}

Expand All @@ -72,7 +72,7 @@ public <T extends TType> AsString asString(Operand<T> input, AsString.Options...
* @param options carries optional attributes values
* @return a new instance of Cast
*/
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT,
public <U extends TType> Cast<U> cast(Operand<? extends TType> x, Class<U> DstT,
Cast.Options... options) {
return Cast.create(scope, x, DstT, options);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public CombinedNonMaxSuppression combinedNonMaxSuppression(Operand<TFloat32> box
* @param options carries optional attributes values
* @return a new instance of CropAndResize
*/
public <T extends TNumber> CropAndResize cropAndResize(Operand<T> image, Operand<TFloat32> boxes,
public CropAndResize cropAndResize(Operand<? extends TNumber> image, Operand<TFloat32> boxes,
Operand<TInt32> boxInd, Operand<TInt32> cropSize, CropAndResize.Options... options) {
return CropAndResize.create(scope, image, boxes, boxInd, cropSize, options);
}
Expand All @@ -239,8 +239,8 @@ public <T extends TNumber> CropAndResize cropAndResize(Operand<T> image, Operand
* @param options carries optional attributes values
* @return a new instance of CropAndResizeGradBoxes
*/
public <T extends TNumber> CropAndResizeGradBoxes cropAndResizeGradBoxes(Operand<TFloat32> grads,
Operand<T> image, Operand<TFloat32> boxes, Operand<TInt32> boxInd,
public CropAndResizeGradBoxes cropAndResizeGradBoxes(Operand<TFloat32> grads,
Operand<? extends TNumber> image, Operand<TFloat32> boxes, Operand<TInt32> boxInd,
CropAndResizeGradBoxes.Options... options) {
return CropAndResizeGradBoxes.create(scope, grads, image, boxes, boxInd, options);
}
Expand Down Expand Up @@ -573,7 +573,7 @@ public EncodeJpegVariableQuality encodeJpegVariableQuality(Operand<TUint8> image
* @param options carries optional attributes values
* @return a new instance of EncodePng
*/
public <T extends TNumber> EncodePng encodePng(Operand<T> image, EncodePng.Options... options) {
public EncodePng encodePng(Operand<? extends TNumber> image, EncodePng.Options... options) {
return EncodePng.create(scope, image, options);
}

Expand Down Expand Up @@ -791,7 +791,7 @@ public <T extends TNumber> RandomCrop<T> randomCrop(Operand<T> image, Operand<TI
* @param options carries optional attributes values
* @return a new instance of ResizeArea
*/
public <T extends TNumber> ResizeArea resizeArea(Operand<T> images, Operand<TInt32> size,
public ResizeArea resizeArea(Operand<? extends TNumber> images, Operand<TInt32> size,
ResizeArea.Options... options) {
return ResizeArea.create(scope, images, size, options);
}
Expand All @@ -807,7 +807,7 @@ public <T extends TNumber> ResizeArea resizeArea(Operand<T> images, Operand<TInt
* @param options carries optional attributes values
* @return a new instance of ResizeBicubic
*/
public <T extends TNumber> ResizeBicubic resizeBicubic(Operand<T> images, Operand<TInt32> size,
public ResizeBicubic resizeBicubic(Operand<? extends TNumber> images, Operand<TInt32> size,
ResizeBicubic.Options... options) {
return ResizeBicubic.create(scope, images, size, options);
}
Expand All @@ -823,7 +823,7 @@ public <T extends TNumber> ResizeBicubic resizeBicubic(Operand<T> images, Operan
* @param options carries optional attributes values
* @return a new instance of ResizeBilinear
*/
public <T extends TNumber> ResizeBilinear resizeBilinear(Operand<T> images, Operand<TInt32> size,
public ResizeBilinear resizeBilinear(Operand<? extends TNumber> images, Operand<TInt32> size,
ResizeBilinear.Options... options) {
return ResizeBilinear.create(scope, images, size, options);
}
Expand Down Expand Up @@ -939,7 +939,7 @@ public <T extends TNumber> SampleDistortedBoundingBox<T> sampleDistortedBounding
* @param options carries optional attributes values
* @return a new instance of ScaleAndTranslate
*/
public <T extends TNumber> ScaleAndTranslate scaleAndTranslate(Operand<T> images,
public ScaleAndTranslate scaleAndTranslate(Operand<? extends TNumber> images,
Operand<TInt32> size, Operand<TFloat32> scale, Operand<TFloat32> translation,
ScaleAndTranslate.Options... options) {
return ScaleAndTranslate.create(scope, images, size, scale, translation, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,8 @@ public ReaderSerializeState readerSerializeState(Operand<?> readerHandle) {
* @param sparseShape 1-D. The `shape` of the minibatch `SparseTensor`.
* @return a new instance of SerializeManySparse
*/
public <T extends TType> SerializeManySparse<TString> serializeManySparse(
Operand<TInt64> sparseIndices, Operand<T> sparseValues, Operand<TInt64> sparseShape) {
public SerializeManySparse<TString> serializeManySparse(Operand<TInt64> sparseIndices,
Operand<? extends TType> sparseValues, Operand<TInt64> sparseShape) {
return SerializeManySparse.create(scope, sparseIndices, sparseValues, sparseShape);
}

Expand All @@ -919,9 +919,8 @@ public <T extends TType> SerializeManySparse<TString> serializeManySparse(
* (default) and `variant`.
* @return a new instance of SerializeManySparse
*/
public <U extends TType, T extends TType> SerializeManySparse<U> serializeManySparse(
Operand<TInt64> sparseIndices, Operand<T> sparseValues, Operand<TInt64> sparseShape,
Class<U> outType) {
public <U extends TType> SerializeManySparse<U> serializeManySparse(Operand<TInt64> sparseIndices,
Operand<? extends TType> sparseValues, Operand<TInt64> sparseShape, Class<U> outType) {
return SerializeManySparse.create(scope, sparseIndices, sparseValues, sparseShape, outType);
}

Expand All @@ -934,8 +933,8 @@ public <U extends TType, T extends TType> SerializeManySparse<U> serializeManySp
* @param sparseShape 1-D. The `shape` of the `SparseTensor`.
* @return a new instance of SerializeSparse
*/
public <T extends TType> SerializeSparse<TString> serializeSparse(Operand<TInt64> sparseIndices,
Operand<T> sparseValues, Operand<TInt64> sparseShape) {
public SerializeSparse<TString> serializeSparse(Operand<TInt64> sparseIndices,
Operand<? extends TType> sparseValues, Operand<TInt64> sparseShape) {
return SerializeSparse.create(scope, sparseIndices, sparseValues, sparseShape);
}

Expand All @@ -950,9 +949,8 @@ public <T extends TType> SerializeSparse<TString> serializeSparse(Operand<TInt64
* (default) and `variant`.
* @return a new instance of SerializeSparse
*/
public <U extends TType, T extends TType> SerializeSparse<U> serializeSparse(
Operand<TInt64> sparseIndices, Operand<T> sparseValues, Operand<TInt64> sparseShape,
Class<U> outType) {
public <U extends TType> SerializeSparse<U> serializeSparse(Operand<TInt64> sparseIndices,
Operand<? extends TType> sparseValues, Operand<TInt64> sparseShape, Class<U> outType) {
return SerializeSparse.create(scope, sparseIndices, sparseValues, sparseShape, outType);
}

Expand All @@ -962,7 +960,7 @@ public <U extends TType, T extends TType> SerializeSparse<U> serializeSparse(
* @param tensor A Tensor of type `T`.
* @return a new instance of SerializeTensor
*/
public <T extends TType> SerializeTensor serializeTensor(Operand<T> tensor) {
public SerializeTensor serializeTensor(Operand<? extends TType> tensor) {
return SerializeTensor.create(scope, tensor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ public <T extends TNumber> CholeskyGrad<T> choleskyGrad(Operand<T> l, Operand<T>
* @param perm
* @return a new instance of ConjugateTranspose
*/
public <T extends TType, U extends TNumber> ConjugateTranspose<T> conjugateTranspose(Operand<T> x,
Operand<U> perm) {
public <T extends TType> ConjugateTranspose<T> conjugateTranspose(Operand<T> x,
Operand<? extends TNumber> perm) {
return ConjugateTranspose.create(scope, x, perm);
}

Expand Down Expand Up @@ -398,7 +398,7 @@ public <T extends TType> Det<T> det(Operand<T> input) {
* @param options carries optional attributes values
* @return a new instance of Eig
*/
public <U extends TType, T extends TType> Eig<U> eig(Operand<T> input, Class<U> Tout,
public <U extends TType> Eig<U> eig(Operand<? extends TType> input, Class<U> Tout,
Eig.Options... options) {
return Eig.create(scope, input, Tout, options);
}
Expand Down Expand Up @@ -505,8 +505,8 @@ public <T extends TType> Einsum<T> einsum(Iterable<Operand<T>> inputs, String eq
* @param options carries optional attributes values
* @return a new instance of EuclideanNorm
*/
public <T extends TType, U extends TNumber> EuclideanNorm<T> euclideanNorm(Operand<T> input,
Operand<U> axis, EuclideanNorm.Options... options) {
public <T extends TType> EuclideanNorm<T> euclideanNorm(Operand<T> input,
Operand<? extends TNumber> axis, EuclideanNorm.Options... options) {
return EuclideanNorm.create(scope, input, axis, options);
}

Expand Down Expand Up @@ -1373,10 +1373,10 @@ public <T extends TType> Qr<T> qr(Operand<T> input, Qr.Options... options) {
* @param options carries optional attributes values
* @return a new instance of QuantizedMatMul
*/
public <V extends TType, T extends TType, U extends TType, W extends TType> QuantizedMatMul<V> quantizedMatMul(
Operand<T> a, Operand<U> b, Operand<TFloat32> minA, Operand<TFloat32> maxA,
Operand<TFloat32> minB, Operand<TFloat32> maxB, Class<V> Toutput, Class<W> Tactivation,
QuantizedMatMul.Options... options) {
public <V extends TType, W extends TType> QuantizedMatMul<V> quantizedMatMul(
Operand<? extends TType> a, Operand<? extends TType> b, Operand<TFloat32> minA,
Operand<TFloat32> maxA, Operand<TFloat32> minB, Operand<TFloat32> maxB, Class<V> Toutput,
Class<W> Tactivation, QuantizedMatMul.Options... options) {
return QuantizedMatMul.create(scope, a, b, minA, maxA, minB, maxB, Toutput, Tactivation, options);
}

Expand Down Expand Up @@ -1544,8 +1544,7 @@ public <T extends TType> TensorDiagPart<T> tensorDiagPart(Operand<T> input) {
* @param perm
* @return a new instance of Transpose
*/
public <T extends TType, U extends TNumber> Transpose<T> transpose(Operand<T> x,
Operand<U> perm) {
public <T extends TType> Transpose<T> transpose(Operand<T> x, Operand<? extends TNumber> perm) {
return Transpose.create(scope, x, perm);
}

Expand Down
Loading

0 comments on commit 5e4b214

Please sign in to comment.