-
Notifications
You must be signed in to change notification settings - Fork 747
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Map
torch::data::datasets::DistributedSampler
and `StreamSampler…
…` from PyTorch (issue #1215)
- Loading branch information
Showing
10 changed files
with
422 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** A wrapper around a batch size value, which implements the | ||
* {@code CustomBatchRequest} interface. */ | ||
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class BatchSize extends CustomBatchRequest { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public BatchSize(Pointer p) { super(p); } | ||
|
||
public BatchSize(@Cast("size_t") long size) { super((Pointer)null); allocate(size); } | ||
private native void allocate(@Cast("size_t") long size); | ||
public native @Cast("size_t") @NoException(true) long size(); | ||
public native @Cast("size_t") @Name("operator size_t") @NoException(true) long asLong(); | ||
public native @Cast("size_t") long size_(); public native BatchSize size_(long setter); | ||
} |
32 changes: 32 additions & 0 deletions
32
pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
@NoOffset @Name("c10::optional<torch::data::samplers::BatchSize>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class BatchSizeOptional extends Pointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public BatchSizeOptional(Pointer p) { super(p); } | ||
public BatchSizeOptional(BatchSize value) { this(); put(value); } | ||
public BatchSizeOptional() { allocate(); } | ||
private native void allocate(); | ||
public native @Name("operator =") @ByRef BatchSizeOptional put(@ByRef BatchSizeOptional x); | ||
|
||
public native boolean has_value(); | ||
public native @Name("value") @ByRef BatchSize get(); | ||
@ValueSetter public native BatchSizeOptional put(@ByRef BatchSize value); | ||
} | ||
|
39 changes: 39 additions & 0 deletions
39
pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
@Name("torch::data::samplers::Sampler<torch::data::samplers::BatchSize>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class BatchSizeSampler extends Pointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public BatchSizeSampler(Pointer p) { super(p); } | ||
|
||
|
||
/** Resets the {@code Sampler}'s internal state. | ||
* Typically called before a new epoch. | ||
* Optionally, accepts a new size when reseting the sampler. */ | ||
public native void reset(@ByVal SizeTOptional new_size); | ||
|
||
/** Returns the next index if possible, or an empty optional if the | ||
* sampler is exhausted for this epoch. */ | ||
public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size); | ||
|
||
/** Serializes the {@code Sampler} to the {@code archive}. */ | ||
public native void save(@ByRef OutputArchive archive); | ||
|
||
/** Deserializes the {@code Sampler} from the {@code archive}. */ | ||
public native void load(@ByRef InputArchive archive); | ||
} |
28 changes: 28 additions & 0 deletions
28
pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
/** A base class for custom index types. */ | ||
@Namespace("torch::data::samplers") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class CustomBatchRequest extends Pointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public CustomBatchRequest(Pointer p) { super(p); } | ||
|
||
|
||
/** The number of elements accessed by this index. */ | ||
public native @Cast("size_t") long size(); | ||
} |
57 changes: 57 additions & 0 deletions
57
pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** Select samples randomly. The sampling order is shuffled at each {@code reset()} | ||
* call. */ | ||
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class DistributedRandomSampler extends DistributedSampler { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public DistributedRandomSampler(Pointer p) { super(p); } | ||
|
||
public DistributedRandomSampler( | ||
@Cast("size_t") long size, | ||
@Cast("size_t") long num_replicas/*=1*/, | ||
@Cast("size_t") long rank/*=0*/, | ||
@Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); } | ||
private native void allocate( | ||
@Cast("size_t") long size, | ||
@Cast("size_t") long num_replicas/*=1*/, | ||
@Cast("size_t") long rank/*=0*/, | ||
@Cast("bool") boolean allow_duplicates/*=true*/); | ||
public DistributedRandomSampler( | ||
@Cast("size_t") long size) { super((Pointer)null); allocate(size); } | ||
private native void allocate( | ||
@Cast("size_t") long size); | ||
|
||
/** Resets the {@code DistributedRandomSampler} to a new set of indices. */ | ||
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size); | ||
public native void reset(); | ||
|
||
/** Returns the next batch of indices. */ | ||
public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size); | ||
|
||
/** Serializes the {@code DistributedRandomSampler} to the {@code archive}. */ | ||
public native void save(@ByRef OutputArchive archive); | ||
|
||
/** Deserializes the {@code DistributedRandomSampler} from the {@code archive}. */ | ||
public native void load(@ByRef InputArchive archive); | ||
|
||
/** Returns the current index of the {@code DistributedRandomSampler}. */ | ||
public native @Cast("size_t") @NoException(true) long index(); | ||
} |
36 changes: 36 additions & 0 deletions
36
pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** A {@code Sampler} that selects a subset of indices to sample from and defines a | ||
* sampling behavior. In a distributed setting, this selects a subset of the | ||
* indices depending on the provided num_replicas and rank parameters. The | ||
* {@code Sampler} performs a rounding operation based on the {@code allow_duplicates} | ||
* parameter to decide the local sample count. */ | ||
@Name("torch::data::samplers::DistributedSampler<std::vector<size_t> >") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class DistributedSampler extends Sampler { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public DistributedSampler(Pointer p) { super(p); } | ||
|
||
|
||
/** Set the epoch for the current enumeration. This can be used to alter the | ||
* sample selection and shuffling behavior. */ | ||
public native void set_epoch(@Cast("size_t") long epoch); | ||
|
||
public native @Cast("size_t") long epoch(); | ||
} |
56 changes: 56 additions & 0 deletions
56
pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** Select samples sequentially. */ | ||
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class DistributedSequentialSampler extends DistributedSampler { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public DistributedSequentialSampler(Pointer p) { super(p); } | ||
|
||
public DistributedSequentialSampler( | ||
@Cast("size_t") long size, | ||
@Cast("size_t") long num_replicas/*=1*/, | ||
@Cast("size_t") long rank/*=0*/, | ||
@Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); } | ||
private native void allocate( | ||
@Cast("size_t") long size, | ||
@Cast("size_t") long num_replicas/*=1*/, | ||
@Cast("size_t") long rank/*=0*/, | ||
@Cast("bool") boolean allow_duplicates/*=true*/); | ||
public DistributedSequentialSampler( | ||
@Cast("size_t") long size) { super((Pointer)null); allocate(size); } | ||
private native void allocate( | ||
@Cast("size_t") long size); | ||
|
||
/** Resets the {@code DistributedSequentialSampler} to a new set of indices. */ | ||
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size); | ||
public native void reset(); | ||
|
||
/** Returns the next batch of indices. */ | ||
public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size); | ||
|
||
/** Serializes the {@code DistributedSequentialSampler} to the {@code archive}. */ | ||
public native void save(@ByRef OutputArchive archive); | ||
|
||
/** Deserializes the {@code DistributedSequentialSampler} from the {@code archive}. */ | ||
public native void load(@ByRef InputArchive archive); | ||
|
||
/** Returns the current index of the {@code DistributedSequentialSampler}. */ | ||
public native @Cast("size_t") @NoException(true) long index(); | ||
} |
50 changes: 50 additions & 0 deletions
50
pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE | ||
|
||
package org.bytedeco.pytorch; | ||
|
||
import org.bytedeco.pytorch.Allocator; | ||
import org.bytedeco.pytorch.Function; | ||
import org.bytedeco.pytorch.Module; | ||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.bytedeco.javacpp.presets.javacpp.*; | ||
import static org.bytedeco.openblas.global.openblas_nolapack.*; | ||
import static org.bytedeco.openblas.global.openblas.*; | ||
|
||
import static org.bytedeco.pytorch.global.torch.*; | ||
|
||
|
||
/** A sampler for (potentially infinite) streams of data. | ||
* | ||
* The major feature of the {@code StreamSampler} is that it does not return | ||
* particular indices, but instead only the number of elements to fetch from | ||
* the dataset. The dataset has to decide how to produce those elements. */ | ||
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class) | ||
public class StreamSampler extends BatchSizeSampler { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public StreamSampler(Pointer p) { super(p); } | ||
|
||
/** Constructs the {@code StreamSampler} with the number of individual examples that | ||
* should be fetched until the sampler is exhausted. */ | ||
public StreamSampler(@Cast("size_t") long epoch_size) { super((Pointer)null); allocate(epoch_size); } | ||
private native void allocate(@Cast("size_t") long epoch_size); | ||
|
||
/** Resets the internal state of the sampler. */ | ||
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size); | ||
public native void reset(); | ||
|
||
/** Returns a {@code BatchSize} object with the number of elements to fetch in the | ||
* next batch. This number is the minimum of the supplied {@code batch_size} and | ||
* the difference between the {@code epoch_size} and the current index. If the | ||
* {@code epoch_size} has been reached, returns an empty optional. */ | ||
public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size); | ||
|
||
/** Serializes the {@code StreamSampler} to the {@code archive}. */ | ||
public native void save(@ByRef OutputArchive archive); | ||
|
||
/** Deserializes the {@code StreamSampler} from the {@code archive}. */ | ||
public native void load(@ByRef InputArchive archive); | ||
} |
Oops, something went wrong.