Map torch::data::samplers::DistributedSampler and `StreamSampler` from PyTorch (issue #1215)
saudet committed Dec 13, 2022
1 parent fa4dfdc commit c0d15db
Showing 10 changed files with 422 additions and 5 deletions.
32 changes: 32 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/BatchSize.java
@@ -0,0 +1,32 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** A wrapper around a batch size value, which implements the
* {@code CustomBatchRequest} interface. */
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class BatchSize extends CustomBatchRequest {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BatchSize(Pointer p) { super(p); }

public BatchSize(@Cast("size_t") long size) { super((Pointer)null); allocate(size); }
private native void allocate(@Cast("size_t") long size);
public native @Cast("size_t") @NoException(true) long size();
public native @Cast("size_t") @Name("operator size_t") @NoException(true) long asLong();
public native @Cast("size_t") long size_(); public native BatchSize size_(long setter);
}
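
A minimal usage sketch (not part of this commit): BatchSize wraps a single size_t value, readable either through size() or through the mapped operator size_t, asLong().

BatchSize bs = new BatchSize(32);
long viaSize = bs.size();    // 32
long viaCast = bs.asLong();  // same value, read through operator size_t
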
32 changes: 32 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeOptional.java
@@ -0,0 +1,32 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@NoOffset @Name("c10::optional<torch::data::samplers::BatchSize>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class BatchSizeOptional extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BatchSizeOptional(Pointer p) { super(p); }
public BatchSizeOptional(BatchSize value) { this(); put(value); }
public BatchSizeOptional() { allocate(); }
private native void allocate();
public native @Name("operator =") @ByRef BatchSizeOptional put(@ByRef BatchSizeOptional x);

public native boolean has_value();
public native @Name("value") @ByRef BatchSize get();
@ValueSetter public native BatchSizeOptional put(@ByRef BatchSize value);
}
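
A minimal sketch (not part of this commit) of the optional wrapper, using only the constructors and accessors declared above:

BatchSizeOptional opt = new BatchSizeOptional();  // default-constructed: empty
boolean present = opt.has_value();                // false until a value is put
opt.put(new BatchSize(16));
long n = opt.get().size();                        // 16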

39 changes: 39 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/BatchSizeSampler.java
@@ -0,0 +1,39 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

@Name("torch::data::samplers::Sampler<torch::data::samplers::BatchSize>") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class BatchSizeSampler extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BatchSizeSampler(Pointer p) { super(p); }


/** Resets the {@code Sampler}'s internal state.
* Typically called before a new epoch.
 * Optionally, accepts a new size when resetting the sampler. */
public native void reset(@ByVal SizeTOptional new_size);

/** Returns the next index if possible, or an empty optional if the
* sampler is exhausted for this epoch. */
public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size);

/** Serializes the {@code Sampler} to the {@code archive}. */
public native void save(@ByRef OutputArchive archive);

/** Deserializes the {@code Sampler} from the {@code archive}. */
public native void load(@ByRef InputArchive archive);
}
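
A minimal sketch (not part of this commit) of the reset/next protocol through this base type, using the StreamSampler subclass added later in this commit and assuming SizeTOptional has an empty-optional default constructor analogous to BatchSizeOptional's:

BatchSizeSampler sampler = new StreamSampler(10);  // concrete subclass from this commit
sampler.reset(new SizeTOptional());                // assumed empty optional: keep the current size
BatchSizeOptional request = sampler.next(4);
while (request.has_value()) {                      // empty once exhausted for the epoch
    long toFetch = request.get().size();           // elements the dataset should produce
    request = sampler.next(4);
}
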
28 changes: 28 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/CustomBatchRequest.java
@@ -0,0 +1,28 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;

/** A base class for custom index types. */
@Namespace("torch::data::samplers") @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class CustomBatchRequest extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public CustomBatchRequest(Pointer p) { super(p); }


/** The number of elements accessed by this index. */
public native @Cast("size_t") long size();
}
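
A short sketch (not part of this commit): BatchSize above is one such custom index type, so it can be handled through this base class:

CustomBatchRequest request = new BatchSize(8);
long n = request.size();  // 8 elements accessed by this index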
57 changes: 57 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/DistributedRandomSampler.java
@@ -0,0 +1,57 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** Select samples randomly. The sampling order is shuffled at each {@code reset()}
* call. */
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class DistributedRandomSampler extends DistributedSampler {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public DistributedRandomSampler(Pointer p) { super(p); }

public DistributedRandomSampler(
@Cast("size_t") long size,
@Cast("size_t") long num_replicas/*=1*/,
@Cast("size_t") long rank/*=0*/,
@Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); }
private native void allocate(
@Cast("size_t") long size,
@Cast("size_t") long num_replicas/*=1*/,
@Cast("size_t") long rank/*=0*/,
@Cast("bool") boolean allow_duplicates/*=true*/);
public DistributedRandomSampler(
@Cast("size_t") long size) { super((Pointer)null); allocate(size); }
private native void allocate(
@Cast("size_t") long size);

/** Resets the {@code DistributedRandomSampler} to a new set of indices. */
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size);
public native void reset();

/** Returns the next batch of indices. */
public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size);

/** Serializes the {@code DistributedRandomSampler} to the {@code archive}. */
public native void save(@ByRef OutputArchive archive);

/** Deserializes the {@code DistributedRandomSampler} from the {@code archive}. */
public native void load(@ByRef InputArchive archive);

/** Returns the current index of the {@code DistributedRandomSampler}. */
public native @Cast("size_t") @NoException(true) long index();
}
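
A minimal usage sketch (not part of this commit), sharding a 10-element dataset across two replicas. SizeTVectorOptional and SizeTVector are not shown in this diff; the sketch assumes they expose has_value()/get() and size()/get(long) accessors following the same JavaCPP wrapper pattern as BatchSizeOptional above.

long datasetSize = 10;
DistributedRandomSampler sampler = new DistributedRandomSampler(
    datasetSize, /*num_replicas=*/2, /*rank=*/0, /*allow_duplicates=*/true);
sampler.reset();                             // shuffle this replica's share of the indices
SizeTVectorOptional batch = sampler.next(4);
while (batch.has_value()) {                  // assumed accessor, by analogy with BatchSizeOptional
    SizeTVector indices = batch.get();
    for (long i = 0; i < indices.size(); i++) {
        long index = indices.get(i);         // a dataset index owned by rank 0
    }
    batch = sampler.next(4);
}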
36 changes: 36 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSampler.java
@@ -0,0 +1,36 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** A {@code Sampler} that selects a subset of indices to sample from and defines a
* sampling behavior. In a distributed setting, this selects a subset of the
 * indices depending on the provided {@code num_replicas} and {@code rank} parameters. The
* {@code Sampler} performs a rounding operation based on the {@code allow_duplicates}
* parameter to decide the local sample count. */
@Name("torch::data::samplers::DistributedSampler<std::vector<size_t> >") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class DistributedSampler extends Sampler {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public DistributedSampler(Pointer p) { super(p); }


/** Set the epoch for the current enumeration. This can be used to alter the
* sample selection and shuffling behavior. */
public native void set_epoch(@Cast("size_t") long epoch);

public native @Cast("size_t") long epoch();
}
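
A short sketch (not part of this commit) of the intended epoch protocol, reusing the DistributedRandomSampler from the previous sketch: calling set_epoch before each reset() lets the sampler vary its shuffle per epoch.

for (long epoch = 0; epoch < 3; epoch++) {
    sampler.set_epoch(epoch);  // alters shuffling for this epoch
    sampler.reset();           // reshuffle, then drain sampler.next(...) as before
}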
56 changes: 56 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/DistributedSequentialSampler.java
@@ -0,0 +1,56 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** Select samples sequentially. */
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class DistributedSequentialSampler extends DistributedSampler {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public DistributedSequentialSampler(Pointer p) { super(p); }

public DistributedSequentialSampler(
@Cast("size_t") long size,
@Cast("size_t") long num_replicas/*=1*/,
@Cast("size_t") long rank/*=0*/,
@Cast("bool") boolean allow_duplicates/*=true*/) { super((Pointer)null); allocate(size, num_replicas, rank, allow_duplicates); }
private native void allocate(
@Cast("size_t") long size,
@Cast("size_t") long num_replicas/*=1*/,
@Cast("size_t") long rank/*=0*/,
@Cast("bool") boolean allow_duplicates/*=true*/);
public DistributedSequentialSampler(
@Cast("size_t") long size) { super((Pointer)null); allocate(size); }
private native void allocate(
@Cast("size_t") long size);

/** Resets the {@code DistributedSequentialSampler} to a new set of indices. */
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size);
public native void reset();

/** Returns the next batch of indices. */
public native @ByVal SizeTVectorOptional next(@Cast("size_t") long batch_size);

/** Serializes the {@code DistributedSequentialSampler} to the {@code archive}. */
public native void save(@ByRef OutputArchive archive);

/** Deserializes the {@code DistributedSequentialSampler} from the {@code archive}. */
public native void load(@ByRef InputArchive archive);

/** Returns the current index of the {@code DistributedSequentialSampler}. */
public native @Cast("size_t") @NoException(true) long index();
}
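
A minimal sketch (not part of this commit): the same two-replica setup as the random sampler above, but rank 1 now receives its half of the indices in sequential order. SizeTVectorOptional is again assumed to follow the wrapper pattern of BatchSizeOptional.

DistributedSequentialSampler sampler = new DistributedSequentialSampler(
    /*size=*/10, /*num_replicas=*/2, /*rank=*/1, /*allow_duplicates=*/true);
sampler.reset();
SizeTVectorOptional batch = sampler.next(5);  // this replica's five indices, in order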
50 changes: 50 additions & 0 deletions pytorch/src/gen/java/org/bytedeco/pytorch/StreamSampler.java
@@ -0,0 +1,50 @@
// Targeted by JavaCPP version 1.5.9-SNAPSHOT: DO NOT EDIT THIS FILE

package org.bytedeco.pytorch;

import org.bytedeco.pytorch.Allocator;
import org.bytedeco.pytorch.Function;
import org.bytedeco.pytorch.Module;
import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;

import static org.bytedeco.pytorch.global.torch.*;


/** A sampler for (potentially infinite) streams of data.
*
* The major feature of the {@code StreamSampler} is that it does not return
* particular indices, but instead only the number of elements to fetch from
* the dataset. The dataset has to decide how to produce those elements. */
@Namespace("torch::data::samplers") @NoOffset @Properties(inherit = org.bytedeco.pytorch.presets.torch.class)
public class StreamSampler extends BatchSizeSampler {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public StreamSampler(Pointer p) { super(p); }

/** Constructs the {@code StreamSampler} with the number of individual examples that
* should be fetched until the sampler is exhausted. */
public StreamSampler(@Cast("size_t") long epoch_size) { super((Pointer)null); allocate(epoch_size); }
private native void allocate(@Cast("size_t") long epoch_size);

/** Resets the internal state of the sampler. */
public native void reset(@ByVal(nullValue = "c10::optional<size_t>(c10::nullopt)") SizeTOptional new_size);
public native void reset();

/** Returns a {@code BatchSize} object with the number of elements to fetch in the
* next batch. This number is the minimum of the supplied {@code batch_size} and
* the difference between the {@code epoch_size} and the current index. If the
* {@code epoch_size} has been reached, returns an empty optional. */
public native @ByVal BatchSizeOptional next(@Cast("size_t") long batch_size);

/** Serializes the {@code StreamSampler} to the {@code archive}. */
public native void save(@ByRef OutputArchive archive);

/** Deserializes the {@code StreamSampler} from the {@code archive}. */
public native void load(@ByRef InputArchive archive);
}
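
A minimal sketch (not part of this commit) of the accounting described in next() above: with an epoch_size of 10 and a batch_size of 4, the returned BatchSize values are 4, 4, then 2, after which the optional comes back empty.

StreamSampler sampler = new StreamSampler(/*epoch_size=*/10);
sampler.reset();
BatchSizeOptional request = sampler.next(4);
while (request.has_value()) {
    long n = request.get().size();  // min(batch_size, epoch_size - index): 4, 4, 2
    // fetch n elements from the stream-backed dataset here
    request = sampler.next(4);
}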