Commit

Merge branch 'main' into please_dont_modify_this_branch_unless_you_are_just_merging_with_main__
NicolasHug authored Nov 12, 2024
2 parents 1966348 + 7d077f1 commit e565ca4
Showing 4 changed files with 24 additions and 77 deletions.
9 changes: 9 additions & 0 deletions docs/source/datasets.rst
@@ -27,6 +27,15 @@ All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively.
You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.

.. warning::

    When a dataset object is created with ``download=True``, the files are first
    downloaded and extracted in the root directory. This download logic is not
    multi-process safe, so it may lead to conflicts / race conditions if it is
    run within a distributed setting. In distributed mode, we recommend creating
    a dummy dataset object to trigger the download logic *before* setting up
    distributed mode.

Image classification
~~~~~~~~~~~~~~~~~~~~

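For context, here is a minimal sketch (not part of this commit) of the pattern the new warning above recommends: a throwaway dataset object is created once in the launcher process, so the download/extract step finishes before any distributed worker starts. The dataset (CIFAR10), the root path, the gloo backend, and the rendezvous address are illustrative assumptions, not taken from the diff.

    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torchvision import datasets

    def worker(rank, world_size):
        # By the time a worker runs, the files already exist on disk, so creating
        # the dataset again (even with download=True) does not re-download.
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
        )
        dataset = datasets.CIFAR10(root="./data", train=True, download=True)
        print(f"rank {rank}: {len(dataset)} samples")
        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = 2
        # Dummy dataset object created before distributed setup: this runs the
        # (non multi-process-safe) download and extraction exactly once.
        datasets.CIFAR10(root="./data", train=True, download=True)
        mp.spawn(worker, args=(world_size,), nprocs=world_size)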
4 changes: 1 addition & 3 deletions torchvision/csrc/io/image/cpu/decode_webp.cpp
@@ -44,12 +44,10 @@ torch::Tensor decode_webp(

auto decoded_data =
decoding_func(encoded_data_p, encoded_data_size, &width, &height);

TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB[A] failed.");

auto deleter = [decoded_data](void*) { WebPFree(decoded_data); };
auto out = torch::from_blob(
decoded_data, {height, width, num_channels}, deleter, torch::kUInt8);
decoded_data, {height, width, num_channels}, torch::kUInt8);

return out.permute({2, 0, 1});
}
87 changes: 14 additions & 73 deletions torchvision/csrc/ops/mps/mps_kernels.h
@@ -5,7 +5,7 @@ namespace ops {

namespace mps {

static const char* METAL_VISION = R"VISION_METAL(
static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
@@ -26,46 +26,15 @@ inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
template <typename T>
inline void atomic_add_float( device T* data_ptr, const T val)
inline void atomic_add_float(device float* data_ptr, const float val)
{
#if __METAL_VERSION__ >= 300
// atomic_float is supported in Metal 3 (macOS Ventura) onward.
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
#else
// Custom atomic addition implementation
// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
// https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
// https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
// Create an atomic uint pointer for atomic transaction.
device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
// Create necessary storage.
uint fetched_uint, assigning_uint;
T fetched_float, assigning_float;
// Replace the value in atom_var with 0 and return the previous value in atom_var.
fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
// Read out the previous value as float.
fetched_float = *( (thread T*) &fetched_uint );
// Do addition and represent the addition result in uint for atomic transaction.
assigning_float = fetched_float + val;
assigning_uint = *((thread uint*) &assigning_float);
// atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
// If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
// Try to assign 0 and get the previously assigned addition result.
uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
T fetched_float_again = *( (thread T*) &fetched_uint_again );
// Re-add again
fetched_float = *((thread T*) &(fetched_uint));
// Previously assigned addition result + addition result from other threads.
assigning_float = fetched_float_again + fetched_float;
assigning_uint = *( (thread uint*) &assigning_float);
}
#endif
atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
}
inline void atomic_add_float(device half* data_ptr, const half val)
{
atomic_fetch_add_explicit((device atomic_float*) data_ptr, static_cast<float>(val), memory_order_relaxed);
}
template <typename T, typename integer_t>
@@ -1061,40 +1030,12 @@ REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
)VISION_METAL";

static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> visionLibrary = nil;
if (visionLibrary) {
return visionLibrary;
}

NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
return visionLibrary;
}

static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}

NSError* error = nil;
id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
)VISION_METAL");

psoCache[kernel] = pso;
return pso;
static id<MTLComputePipelineState> visionPipelineState(
id<MTLDevice> device,
const std::string& kernel) {
return lib.getPipelineStateForFunc(kernel);
}

} // namespace mps
1 change: 0 additions & 1 deletion torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm
@@ -123,7 +123,6 @@

float spatial_scale_f = static_cast<float>(spatial_scale);

auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

if (grad.numel() == 0) {
