From c0d15db924e1c10de7e6999ac2d105ff479e5754 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Tue, 13 Dec 2022 10:27:16 +0900 Subject: [PATCH] * Map `torch::data::datasets::DistributedSampler` and `StreamSampler` from PyTorch (issue #1215) --- .../java/org/bytedeco/pytorch/BatchSize.java | 32 +++++++ .../bytedeco/pytorch/BatchSizeOptional.java | 32 +++++++ .../bytedeco/pytorch/BatchSizeSampler.java | 39 +++++++++ .../bytedeco/pytorch/CustomBatchRequest.java | 28 +++++++ .../pytorch/DistributedRandomSampler.java | 57 +++++++++++++ .../bytedeco/pytorch/DistributedSampler.java | 36 ++++++++ .../pytorch/DistributedSequentialSampler.java | 56 +++++++++++++ .../org/bytedeco/pytorch/StreamSampler.java | 50 +++++++++++ .../org/bytedeco/pytorch/global/torch.java | 84 +++++++++++++++++++ .../org/bytedeco/pytorch/presets/torch.java | 13 +-- 10 files changed, 422 insertions(+), 5 deletions(-) create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/BatchSize.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java create mode 100644 pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSize.java b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSize.java new file mode 100644 index 00000000000..b6e4526b26c --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSize.java @@ -0,0 +1,32 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + + +/** A wrapper around a batch size value, which implements the + * {@code CustomBatchRequest} interface. */ +@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class BatchSize extends CustomBatchRequest { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BatchSize(Pointer p) { super(p); } + + public BatchSize(@Cast("size_t") long size) { super((Pointer)null); allocate(size); } + private native void allocate(@Cast("size_t") long size); + public native @Cast("size_t") @NoException(true) long size(); + public native @Cast("size_t") @Name("operator size_t") @NoException(true) long asLong(); + public native @Cast("size_t") long size_(); public native BatchSize size_(long setter); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java new file mode 100644 index 00000000000..fa2c7b2e349 --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java @@ -0,0 +1,32 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + +@NoOffset @Name("c10::optional") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class BatchSizeOptional extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BatchSizeOptional(Pointer p) { super(p); } + public BatchSizeOptional(BatchSize value) { this(); put(value); } + public BatchSizeOptional() { allocate(); } + private native void allocate(); + public native @Name("operator =") @ByRef BatchSizeOptional put(@ByRef BatchSizeOptional x); + + public native boolean has_value(); + public native @Name("value") @ByRef BatchSize get(); + @ValueSetter public native BatchSizeOptional put(@ByRef BatchSize value); +} + diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java new file mode 100644 index 00000000000..dbf766ff925 --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java @@ -0,0 +1,39 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + +@Name("torch::data::samplers::Sampler") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class BatchSizeSampler extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public BatchSizeSampler(Pointer p) { super(p); } + + + /** Resets the {@code Sampler}'s internal state. + * Typically called before a new epoch. + * Optionally, accepts a new size when reseting the sampler. */ + public native void reset(@ByVal SizeTOptional new_size); + + /** Returns the next index if possible, or an empty optional if the + * sampler is exhausted for this epoch. */ + public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size); + + /** Serializes the {@code Sampler} to the {@code archive}. */ + public native void save(@ByRef OutputArchive archive); + + /** Deserializes the {@code Sampler} from the {@code archive}. */ + public native void load(@ByRef InputArchive archive); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java b/pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java new file mode 100644 index 00000000000..02a2000c22a --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java @@ -0,0 +1,28 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + +/** A base class for custom index types. */ +@Namespace("torch::data::samplers") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class CustomBatchRequest extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public CustomBatchRequest(Pointer p) { super(p); } + + + /** The number of elements accessed by this index. */ + public native @Cast("size_t") long size(); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java new file mode 100644 index 00000000000..61b5d60af3d --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java @@ -0,0 +1,57 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + + +/** Select samples randomly. The sampling order is shuffled at each {@code reset()} + * call. */ +@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class DistributedRandomSampler extends DistributedSampler { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DistributedRandomSampler(Pointer p) { super(p); } + + public DistributedRandomSampler( + @Cast("size_t") long size, + @Cast("size_t") long num_replicas/*=1*/, + @Cast("size_t") long rank/*=0*/, + @Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); } + private native void allocate( + @Cast("size_t") long size, + @Cast("size_t") long num_replicas/*=1*/, + @Cast("size_t") long rank/*=0*/, + @Cast("bool") boolean allow_duplicates/*=true*/); + public DistributedRandomSampler( + @Cast("size_t") long size) { super((Pointer)null); allocate(size); } + private native void allocate( + @Cast("size_t") long size); + + /** Resets the {@code DistributedRandomSampler} to a new set of indices. */ + public native void reset(@ByVal(nullValue = "c10::optional(c10::nullopt)") SizeTOptional new_size); + public native void reset(); + + /** Returns the next batch of indices. */ + public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size); + + /** Serializes the {@code DistributedRandomSampler} to the {@code archive}. */ + public native void save(@ByRef OutputArchive archive); + + /** Deserializes the {@code DistributedRandomSampler} from the {@code archive}. */ + public native void load(@ByRef InputArchive archive); + + /** Returns the current index of the {@code DistributedRandomSampler}. */ + public native @Cast("size_t") @NoException(true) long index(); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java new file mode 100644 index 00000000000..9a1460d271e --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java @@ -0,0 +1,36 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + + +/** A {@code Sampler} that selects a subset of indices to sample from and defines a + * sampling behavior. In a distributed setting, this selects a subset of the + * indices depending on the provided num_replicas and rank parameters. The + * {@code Sampler} performs a rounding operation based on the {@code allow_duplicates} + * parameter to decide the local sample count. */ +@Name("torch::data::samplers::DistributedSampler >") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class DistributedSampler extends Sampler { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DistributedSampler(Pointer p) { super(p); } + + + /** Set the epoch for the current enumeration. This can be used to alter the + * sample selection and shuffling behavior. */ + public native void set_epoch(@Cast("size_t") long epoch); + + public native @Cast("size_t") long epoch(); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java new file mode 100644 index 00000000000..0260dc09e00 --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java @@ -0,0 +1,56 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + + +/** Select samples sequentially. */ +@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class DistributedSequentialSampler extends DistributedSampler { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public DistributedSequentialSampler(Pointer p) { super(p); } + + public DistributedSequentialSampler( + @Cast("size_t") long size, + @Cast("size_t") long num_replicas/*=1*/, + @Cast("size_t") long rank/*=0*/, + @Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); } + private native void allocate( + @Cast("size_t") long size, + @Cast("size_t") long num_replicas/*=1*/, + @Cast("size_t") long rank/*=0*/, + @Cast("bool") boolean allow_duplicates/*=true*/); + public DistributedSequentialSampler( + @Cast("size_t") long size) { super((Pointer)null); allocate(size); } + private native void allocate( + @Cast("size_t") long size); + + /** Resets the {@code DistributedSequentialSampler} to a new set of indices. */ + public native void reset(@ByVal(nullValue = "c10::optional(c10::nullopt)") SizeTOptional new_size); + public native void reset(); + + /** Returns the next batch of indices. */ + public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size); + + /** Serializes the {@code DistributedSequentialSampler} to the {@code archive}. */ + public native void save(@ByRef OutputArchive archive); + + /** Deserializes the {@code DistributedSequentialSampler} from the {@code archive}. */ + public native void load(@ByRef InputArchive archive); + + /** Returns the current index of the {@code DistributedSequentialSampler}. */ + public native @Cast("size_t") @NoException(true) long index(); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java b/pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java new file mode 100644 index 00000000000..b59d0fe0dbf --- /dev/null +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java @@ -0,0 +1,50 @@ +// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE + +package org.bytedeco.pytorch; + +import org.bytedeco.pytorch.Allocator; +import org.bytedeco.pytorch.Function; +import org.bytedeco.pytorch.Module; +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.bytedeco.javacpp.presets.javacpp.*; +import static org.bytedeco.openblas.global.openblas_nolapack.*; +import static org.bytedeco.openblas.global.openblas.*; + +import static org.bytedeco.pytorch.global.torch.*; + + +/** A sampler for (potentially infinite) streams of data. + * + * The major feature of the {@code StreamSampler} is that it does not return + * particular indices, but instead only the number of elements to fetch from + * the dataset. The dataset has to decide how to produce those elements. */ +@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) +public class StreamSampler extends BatchSizeSampler { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public StreamSampler(Pointer p) { super(p); } + + /** Constructs the {@code StreamSampler} with the number of individual examples that + * should be fetched until the sampler is exhausted. */ + public StreamSampler(@Cast("size_t") long epoch_size) { super((Pointer)null); allocate(epoch_size); } + private native void allocate(@Cast("size_t") long epoch_size); + + /** Resets the internal state of the sampler. */ + public native void reset(@ByVal(nullValue = "c10::optional(c10::nullopt)") SizeTOptional new_size); + public native void reset(); + + /** Returns a {@code BatchSize} object with the number of elements to fetch in the + * next batch. This number is the minimum of the supplied {@code batch_size} and + * the difference between the {@code epoch_size} and the current index. If the + * {@code epoch_size} has been reached, returns an empty optional. */ + public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size); + + /** Serializes the {@code StreamSampler} to the {@code archive}. */ + public native void save(@ByRef OutputArchive archive); + + /** Deserializes the {@code StreamSampler} from the {@code archive}. */ + public native void load(@ByRef InputArchive archive); +} diff --git a/pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java b/pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java index b2416a6b11c..ac974ded4bb 100644 --- a/pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java +++ b/pytorch/src/gen/java/org/bytedeco/pytorch/global/torch.java @@ -201,6 +201,9 @@ public class torch extends org.bytedeco.pytorch.presets.torch { // Targeting ../ExampleOptional.java +// Targeting ../BatchSizeOptional.java + + // Targeting ../TensorTensorOptional.java @@ -75137,6 +75140,49 @@ The list of (type, depth) pairs controls the type of specializations and the num // Targeting ../Sampler.java +// Targeting ../BatchSizeSampler.java + + + + // namespace samplers + // namespace data + // namespace torch + + +// Parsed from torch/data/samplers/custom_batch_request.h + +// #pragma once + +// #include +// #include +// Targeting ../CustomBatchRequest.java + + + // namespace samplers + // namespace data + // namespace torch + + +// Parsed from torch/data/samplers/distributed.h + +// #pragma once + +// #include +// #include + +// #include +// #include + // namespace serialize + // namespace torch +// Targeting ../DistributedSampler.java + + +// Targeting ../DistributedRandomSampler.java + + +// Targeting ../DistributedSequentialSampler.java + + // namespace samplers // namespace data @@ -75179,6 +75225,44 @@ The list of (type, depth) pairs controls the type of specializations and the num + // namespace samplers + // namespace data + // namespace torch + + +// Parsed from torch/data/samplers/serialize.h + +// #pragma once + +// #include +// #include +/** Serializes a {@code Sampler} into an {@code OutputArchive}. */ + +/** Deserializes a {@code Sampler} from an {@code InputArchive}. */ + // namespace samplers + // namespace data + // namespace torch + + +// Parsed from torch/data/samplers/stream.h + +// #pragma once + +// #include +// #include +// #include +// #include + +// #include + // namespace serialize + // namespace torch +// Targeting ../BatchSize.java + + +// Targeting ../StreamSampler.java + + + // namespace samplers // namespace data // namespace torch diff --git a/pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java b/pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java index 24d33bdbe0f..a8b1a5893d8 100644 --- a/pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java +++ b/pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java @@ -1611,12 +1611,12 @@ "torch/data/datasets/tensor.h", "torch/data/samplers.h", "torch/data/samplers/base.h", -// "torch/data/samplers/custom_batch_request.h", -// "torch/data/samplers/distributed.h", + "torch/data/samplers/custom_batch_request.h", + "torch/data/samplers/distributed.h", "torch/data/samplers/random.h", "torch/data/samplers/sequential.h", -// "torch/data/samplers/serialize.h", -// "torch/data/samplers/stream.h", + "torch/data/samplers/serialize.h", + "torch/data/samplers/stream.h", "torch/data/transforms.h", "torch/data/transforms/base.h", "torch/data/transforms/collate.h", @@ -2548,11 +2548,14 @@ public void map(InfoMap infoMap) { .put(new Info("torch::data::Iterator > >").purify().pointerTypes("ExampleVectorIterator")) .put(new Info("torch::data::Iterator > > >").purify().pointerTypes("ExampleVectorOptionalIterator")) .put(new Info("torch::data::samplers::Sampler >", "torch::data::samplers::Sampler<>").pointerTypes("Sampler")) + .put(new Info("torch::data::samplers::Sampler").pointerTypes("BatchSizeSampler")) + .put(new Info("torch::data::samplers::RandomSampler").pointerTypes("RandomSampler")) + .put(new Info("torch::data::samplers::DistributedSampler >", "torch::data::samplers::DistributedSampler<>").purify().pointerTypes("DistributedSampler")) + .put(new Info("c10::optional").pointerTypes("BatchSizeOptional").define()) .put(new Info("torch::data::transforms::BatchTransform >, torch::data::Example<> >", "torch::data::transforms::Collation >").pointerTypes("ExampleCollation")) .put(new Info("torch::data::transforms::Stack >").pointerTypes("ExampleStack")) - .put(new Info("torch::data::samplers::RandomSampler").pointerTypes("RandomSampler")) .put(new Info("torch::data::datasets::ChunkDataReader,std::vector > >", VirtualChunkDataReader).pointerTypes("ChunkDataReader").virtualize()) .put(new Info("torch::data::datasets::ChunkDataset<" + VirtualChunkDataReader + ",torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>").pointerTypes("ChunkDataset")) .put(new Info("torch::data::datasets::ChunkDataset<" + VirtualChunkDataReader + ",torch::data::samplers::RandomSampler,torch::data::samplers::RandomSampler>::ChunkDataset").javaText(