Commit
Update on "Enable stateless XNNPACK linear."
Optimal usage of the linear operator requires weight prepacking. If we have done our job properly, the JIT will have caught and replaced linear operations on mobile with their prepacked counterparts, enabling that optimal path. Still, if we end up in at::native::linear for whatever reason, going through XNNPACK is more efficient than the alternatives of at::addmm and at::matmul.
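
For orientation, the fallback order described above corresponds roughly to the sketch below. This is a minimal illustration only; the xnnpack::use_linear / xnnpack::linear helper names and the mobile-only guard are assumptions about how the availability check and the stateless call are spelled, not the literal patch.

```cpp
// Illustrative sketch of the dispatch order described above; not the literal patch.
#include <ATen/ATen.h>

namespace at { namespace native {

Tensor linear_sketch(const Tensor& input, const Tensor& weight, const Tensor& bias) {
#if defined(C10_MOBILE)
  // Prefer the stateless XNNPACK path when the inputs qualify (assumed helpers).
  if (xnnpack::use_linear(input, weight, bias)) {
    return xnnpack::linear(input, weight, bias);
  }
#endif
  // Otherwise fall back to addmm (2-D input with bias) or matmul.
  if (input.dim() == 2 && bias.defined()) {
    return at::addmm(bias, input, weight.t());
  }
  auto output = at::matmul(input, weight.t());
  if (bias.defined()) {
    output.add_(bias);
  }
  return output;
}

}} // namespace at::native
```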

Differential Revision: [D20821863](https://our.internmc.facebook.com/intern/diff/D20821863)

[ghstack-poisoned]
Ashkan Aliabadi committed Apr 22, 2020
2 parents 8d01b21 + fd083ad commit 9ec51d7
Showing 90 changed files with 2,222 additions and 1,628 deletions.
6 changes: 6 additions & 0 deletions BUILD.bazel
@@ -193,6 +193,12 @@ genrule(
"torch/csrc/autograd/generated/VariableType_3.cpp",
"torch/csrc/autograd/generated/VariableType_4.cpp",
# "torch/csrc/autograd/generated/VariableTypeEverything.cpp",
"torch/csrc/autograd/generated/ProfiledType_0.cpp",
"torch/csrc/autograd/generated/ProfiledType_1.cpp",
"torch/csrc/autograd/generated/ProfiledType_2.cpp",
"torch/csrc/autograd/generated/ProfiledType_3.cpp",
"torch/csrc/autograd/generated/ProfiledType_4.cpp",
# "torch/csrc/autograd/generated/ProfiledTypeEverything.cpp",
"torch/csrc/autograd/generated/RegistrationDeclarations.h",
"torch/csrc/autograd/generated/Functions.h",
"torch/csrc/autograd/generated/Functions.cpp",
2 changes: 0 additions & 2 deletions CMakeLists.txt
@@ -204,8 +204,6 @@ cmake_dependent_option(
cmake_dependent_option(
USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
# NB: USE_NCCL is intentionally left independent from USE_DISTRIBUTED, because
# DataParallel also uses NCCL.
option(USE_TBB "Use TBB" OFF)
option(ONNX_ML "Enable traditional ONNX ML API." ON)

26 changes: 9 additions & 17 deletions aten/src/ATen/Dispatch.h
@@ -6,26 +6,18 @@
#include <c10/util/Exception.h>
#include <ATen/core/DeprecatedTypeProperties.h>

// Workaround for C10_UNUSED because CUDA 9.2 fails to handle unused attribute in the type aliasing context.
// Keep name long and verbose to avoid macro collisions.
#if defined(__CUDACC__) && CUDA_VERSION <= 9200
#define C10_UNUSED_DISPATCH_CUDA9_WORKAROUND
#else
#define C10_UNUSED_DISPATCH_CUDA9_WORKAROUND C10_UNUSED
#endif // defined(__CUDACC__) && CUDA_VERSION <= 9200

#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
case enum_type: { \
using scalar_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = type; \
return __VA_ARGS__(); \
#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
case enum_type: { \
using scalar_t = type; \
return __VA_ARGS__(); \
}

#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_enum, underlying_type, ...) \
case enum_type: { \
const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = underlying_enum; \
using scalar_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = type; \
using underlying_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = underlying_type; \
return __VA_ARGS__(); \
case enum_type: { \
const auto& UNDERLYING_TYPE C10_UNUSED = underlying_enum; \
using scalar_t C10_UNUSED = type; \
using underlying_t C10_UNUSED = underlying_type; \
return __VA_ARGS__(); \
}

// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
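As context for the Dispatch.h hunk above, the case macros are consumed by the AT_DISPATCH_* wrappers, which switch on a ScalarType and bind scalar_t for a lambda body. A hedged usage sketch follows; the kernel itself is illustrative and assumes a contiguous CPU floating-point input, it is not code from this commit.

```cpp
// Illustrative use of the dispatch machinery; not code from this commit.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

at::Tensor add_one(const at::Tensor& self) {
  at::Tensor input = self.contiguous();
  at::Tensor out = at::empty_like(input);
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "add_one", [&] {
    // Inside this lambda, scalar_t is bound to float or double by the
    // AT_PRIVATE_CASE_TYPE expansion shown in the hunk above.
    const scalar_t* src = input.data_ptr<scalar_t>();
    scalar_t* dst = out.data_ptr<scalar_t>();
    for (int64_t i = 0; i < input.numel(); ++i) {
      dst[i] = src[i] + scalar_t(1);
    }
  });
  return out;
}
```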
8 changes: 3 additions & 5 deletions aten/src/ATen/core/DistributionsHelper.h
@@ -11,6 +11,8 @@
#include <ATen/core/Array.h>
#include <c10/util/Half.h>
#include <c10/util/Optional.h>
#include <c10/macros/Macros.h>

#include <type_traits>
#include <limits>
#include <cmath>
@@ -214,11 +216,7 @@ struct exponential_distribution {
}

template <typename RNG>
inline T operator()(RNG* generator) {
// Follows numpy exponential for the case when lambda is zero.
if (lambda == static_cast<T>(0.0)) {
return static_cast<T>(0.0);
}
__ubsan_ignore_float_divide_by_zero__ inline T operator()(RNG* generator) {
uniform_real_distribution<T> uniform(0.0, 1.0);
dist_acctype<T> sample = uniform(generator);
return static_cast<T>(-1.0) / lambda * ::log(static_cast<T>(1.0)-sample);
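The DistributionsHelper.h hunk above drops the explicit lambda == 0 early return and instead annotates the sampler with __ubsan_ignore_float_divide_by_zero__. The sampler is the standard inverse-transform construction, sketched below for scalars; this is illustrative only and is not the library code.

```cpp
// Scalar sketch of inverse-transform sampling for Exp(lambda); illustrative only.
// With u ~ Uniform(0, 1), x = -log(1 - u) / lambda is Exp(lambda)-distributed.
// When lambda == 0 the division produces +inf under IEEE arithmetic, which is
// why the sampler above silences UBSan's float-divide-by-zero check rather
// than special-casing lambda == 0.
#include <cmath>
#include <random>

double sample_exponential(std::mt19937_64& gen, double lambda) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double u = uniform(gen);
  return -std::log(1.0 - u) / lambda;
}
```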
6 changes: 6 additions & 0 deletions aten/src/ATen/core/type.cpp
@@ -20,6 +20,9 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "Tensor";
}
if (auto ndim = value->sizes().size()) {
bool has_valid_strides_info =
value->strides().isComplete() && value->strides().size() == ndim;

out << "(";
for (size_t i = 0; i < *ndim; ++i) {
if (i > 0) {
@@ -30,6 +33,9 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else {
out << "*";
}
if (has_valid_strides_info) {
out << ":" << *value->strides()[i];
}
}
out << ")";
}
14 changes: 7 additions & 7 deletions aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -2,7 +2,7 @@

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>
#include <ATen/native/quantized/affine_quantizer.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <c10/util/qint32.h>
@@ -212,7 +212,7 @@ inline void __attribute__((always_inline)) QuantizeAvx2(
dst[i] = nearbyint(clipped);
}
#else
at::native::quantize_vec<T>(
at::quantize_vec<T>(
1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
#endif
}
@@ -278,7 +278,7 @@ struct Vec256<c10::qint32> : public Vec256qi {
float inverse_scale) {
Vec256<c10::qint32> retval;
auto rhs_data = (__m256)rhs[0];
at::native::quantize_vec<c10::qint32, /*precision=*/32>(
at::quantize_vec<c10::qint32, /*precision=*/32>(
scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
return retval;
}
@@ -1094,7 +1094,7 @@ struct Vec256QuantizedConverter {
for (int i = 0; i < float_num_vecs(); ++i) {
for (int j = 0; j < 8; ++j) {
rv[i][j] =
at::native::dequantize_val<T>(scale[j], zero_point[j], T(vals[8 * i + j]));
at::dequantize_val<T>(scale[j], zero_point[j], T(vals[8 * i + j]));
}
}
return rv;
@@ -1152,7 +1152,7 @@ struct Vec256<c10::qint32> : public Vec256QuantizedConverter<
rhs[i].store(&float_vals[i * 8], 8);
}

at::native::quantize_vec<c10::qint32, /*precision=*/32>(
at::quantize_vec<c10::qint32, /*precision=*/32>(
scale,
zero_point,
float_vals.data(),
@@ -1284,7 +1284,7 @@ struct Vec256<c10::qint8> : public Vec256QuantizedConverter<
rhs[i].store(&float_vals[i * 8], 8);
}

at::native::quantize_vec<c10::qint8>(
at::quantize_vec<c10::qint8>(
scale,
zero_point,
float_vals.data(),
@@ -1404,7 +1404,7 @@ struct Vec256<c10::quint8> : public Vec256QuantizedConverter<
rhs[i].store(&float_vals[i * 8], 8);
}

at::native::quantize_vec<c10::quint8>(
at::quantize_vec<c10::quint8>(
scale,
zero_point,
float_vals.data(),
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -406,7 +406,7 @@ inline bool CUDA_tensor_apply2(at::Tensor a,
const Op op,
TensorArgType aType = TensorArgType::ReadWrite,
TensorArgType bType = TensorArgType::ReadOnly) {
checkDeviceType("CUDA_tensor_apply2", {a, b}, DeviceType::CUDA);
checkBackend("CUDA_tensor_apply2", {a, b}, Backend::CUDA);
int64_t totalElements = a.numel();

if (totalElements != b.numel()) {
10 changes: 3 additions & 7 deletions aten/src/ATen/gen.py
@@ -171,14 +171,12 @@ def check_all_files_written(self):
def backend_to_devicetype(backend):
if backend == 'QuantizedCPU':
return 'CPU'
elif backend == 'QuantizedCUDA':
return 'CUDA'
return backend

backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse', 'Mkldnn'] # TODO: layout instead of densities?

quantized_backends = ['QuantizedCPU', 'QuantizedCUDA']
quantized_backends = ['QuantizedCPU']

# scalar_name, c_type, accreal, is_floating_type
quantized_scalar_types = [
Expand Down Expand Up @@ -213,8 +211,6 @@ def backend_to_devicetype(backend):
def is_whitelisted_backend(backend):
return options.backend_whitelist is None or backend in options.backend_whitelist

def is_cuda_backend(backend):
return backend in ("QuantizedCUDA", "CUDA")

def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
@@ -298,7 +294,7 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi
top_env['type_ids'].append(tag + ',')

env['legacy_th_headers'] = []
if is_cuda_backend(backend):
if backend == 'CUDA':
env['extra_cuda_headers'] = []
env['extra_cuda_headers'].append('#include <ATen/DeviceGuard.h>')
if options.rocm:
@@ -407,7 +403,7 @@ def declare_outputs():
if not is_whitelisted_backend(full_backend):
continue
fm = file_manager
if is_cuda_backend(backend):
if backend == 'CUDA':
fm = cuda_file_manager
for kind in ["Type"]:
if kind != 'Type' and density == "Sparse":
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Copy.cpp
@@ -112,7 +112,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
}

if (self.is_quantized() && !src.is_quantized()) {
return quantized_copy_from_float_cpu_(self, src);
return quantized_copy_from_float_(self, src);
}

if (self.is_quantized() && src.is_quantized()) {
83 changes: 50 additions & 33 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -18,42 +18,60 @@ bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol,
return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
}

// Note [closeness]
// A number A is close to B when either:
//
// (1) A is equal to B, with NaNs comparing equal when equal_nan is true.
// (2) The error abs(A - B) is finite and less than the max error
// (atol + abs(rtol * B)).
//
// Note that this is consistent with NumPy's isclose but divergent from
// Python's isclose, which computes the max error symmetrically as
// max(rtol * max(abs(A), abs(B)), atol).
// TODO: use bitwise operator overloads once we add them
// TODO: revisit complex inputs and equal_nan=true after
// https://github.com/numpy/numpy/issues/15959 is resolved
Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
// TODO: use bitwise operator overloads once we add them

TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type())

// The original formula `atol + rtol * other.abs()` works incorrectly when
// `other` has integral dtype and `other == min_value` and `abs(min_value)` is negative:
// std::abs(std::numeric_limits<int64_t>::lowest()) == std::numeric_limits<int64_t>::lowest() < 0
auto max_error = atol + (rtol * other).abs();

// `max_error` could be a float or double depending on the type of the input
// tensors.
// Specifically, if other is an int tensor, multiplying by rtol results in
// float tensor.
// It is also possible for parameters to be 'wrapped_number's, in which case
// max_error could be promoted to double when actual error is still a float.
Tensor actual_error;
if (actual_error.scalar_type() != max_error.scalar_type()) {
// To silence ASAN that does not like (x - std::numeric_limits<int64_t>::lowest())
actual_error = (self - other.to(max_error.scalar_type())).abs();
} else {
actual_error = (self - other).abs();
TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type());
TORCH_CHECK(!(self.is_complex() && equal_nan),
"isclose with equal_nan=True is not supported for complex inputs.");

// Checks that rtol and atol are non-negative
// Note: consistent with Python's isclose but divergent from NumPy's, which
// allows negative atol and rtol.
TORCH_CHECK(rtol >= 0, "rtol must be greater than or equal to zero, but got ", rtol);
TORCH_CHECK(atol >= 0, "atol must be greater than or equal to zero, but got ", atol);

// Computes equality closeness
Tensor close = self == other;
if (equal_nan && self.is_floating_point()) {
close.__ior__((self != self).__iand__(other != other));
}

auto close = actual_error <= max_error;
// Note [closeness error computation]
// atol and rtol are provided as doubles, so the computation
// rtol * other will produce a float or complex tensor.
// When the difference (self - other) is compared to it then the
// tensor representing the difference will also be cast to float or complex.
// However, since (self - other) in uint8 is very likely to produce a
// negative value, this moves the cast forward so the difference is
// always computed in a float or complex type.
// If the values of the integer tensors cannot be exactly represented
// by the default scalar type then this may cause an incorrect result.

// Computes allowed and actual error
Tensor cast_other;
if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
cast_other = other.to(at::get_default_dtype());
} else {
cast_other = other;
}
Tensor allowed_error = atol + (rtol * cast_other).abs();
Tensor actual_error = (self - cast_other).abs();

if (isFloatingType(self.scalar_type()) && isFloatingType(other.scalar_type())) {
// Handle +/-inf
close.__ior__(self == other);
close.__iand__((self == INFINITY) == (other == INFINITY));
close.__iand__((self == -INFINITY) == (other == -INFINITY));
// Computes finite closeness
close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));

if (equal_nan) {
close.__ior__((self != self).__and__((other != other)));
}
}
return close;
}

@@ -87,8 +105,7 @@ Tensor isfinite(const Tensor& self) {

// Note: a complex value is finite iff both parts are finite
if (self.is_complex()) {
const auto float_type = c10::toValueType(self.scalar_type());
return at::isfinite(self.abs().to(float_type));
return at::isfinite(self.abs());
}

return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "isfinite", [&]() {
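For the scalar intuition behind Note [closeness] above, here is a minimal self-contained sketch; it is illustrative only and ignores the tensor-level dtype promotion and casting concerns the diff's comments discuss.

```cpp
// Scalar model of the closeness predicate from Note [closeness]; not library code.
// A is close to B when A == B (with NaNs equal only if equal_nan), or when
// |A - B| is finite and no larger than atol + |rtol * B|.
#include <cmath>

bool isclose_scalar(double a, double b, double rtol, double atol, bool equal_nan) {
  if (a == b) {
    return true;
  }
  if (equal_nan && std::isnan(a) && std::isnan(b)) {
    return true;
  }
  const double allowed_error = atol + std::fabs(rtol * b);
  const double actual_error = std::fabs(a - b);
  return std::isfinite(actual_error) && actual_error <= allowed_error;
}
```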
13 changes: 3 additions & 10 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -65,16 +65,9 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
copy_stream));
}
} else {
auto dtype = iter.dtype(0);
if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
});
}
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
});
}

if (src_device != dst_device) {
6 changes: 0 additions & 6 deletions aten/src/ATen/native/cuda/DistributionExponentialKernel.cu
@@ -40,9 +40,6 @@ void exponential_kernel(TensorIterator& iter, double lambda_, c10::optional<Gene
if (std::is_same<scalar_t, double>::value) {
// define lambda for exponential transformation
auto exponential_func = [lambda, nextafter_1_0_double] __device__ (accscalar_t rand) {
if (lambda == static_cast<accscalar_t>(0.0)) {
return static_cast<scalar_t>(0.0);
}
accscalar_t sample;
// curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
// Hence, squash the 1 to just below 1.
@@ -60,9 +57,6 @@
} else {
// use __logf fast approximation for peak bandwidth
auto exponential_func = [lambda, nextafter_1_0_float] __device__ (accscalar_t rand) {
if (lambda == static_cast<accscalar_t>(0.0)) {
return static_cast<scalar_t>(0.0);
}
accscalar_t sample;
if(rand == static_cast<accscalar_t>(1.0)) {
sample = __logf(nextafter_1_0_float);
(The remaining changed files are not shown here.)