home:birdwatcher:machinelearning / python-torch-stable / pytorch-optionally-use-hipblaslt-2-3-1.patch
File pytorch-optionally-use-hipblaslt-2-3-1.patch of Package python-torch-stable
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index d534ec5a1..e815463f6 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,7 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
+#ifdef HIPBLASLT
 #include <hipblaslt/hipblaslt-ext.hpp>
 #endif
 // until hipblas has an API to accept flags, we must use rocblas here
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 
 #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
 // only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
 };
 } // namespace
 
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 template <typename Dtype>
 void gemm_and_bias(
     bool transpose_mat1,
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
     at::BFloat16* result_ptr,
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);
-
+#endif
 void scaled_gemm(
     char transa,
     char transb,
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index eb12bb350..068607467 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index 4ec35f59a..e28dc4203 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@
 
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 #include <cublasLt.h>
 #endif
 
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 #endif
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 6913d2cd9..3d4276be3 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
 
 namespace {
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
 void createCublasLtHandle(cublasLtHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
 }
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 #ifdef USE_ROCM
   c10::DeviceIndex device = 0;
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 3ba0d7612..dde1870cf 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@
 #include <ATen/cuda/tunable/GemmCommon.h>
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
 #endif
 #include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
   }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
   static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
   if (env == nullptr || strcmp(env, "1") == 0) {
     // disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
   }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
   static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
   if (env == nullptr || strcmp(env, "1") == 0) {
     // disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 29e5c5e3c..df56f3d7f 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -155,7 +155,7 @@ enum class Activation {
   GELU,
 };
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
     std::string device_arch = prop->gcnArchName;
     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
         }
     }
     TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
     return false;
 }
 #endif
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     useLtInterface = !disable_addmm_cuda_lt &&
         result.dim() == 2 && result.is_contiguous() &&
         isSupportedHipLtROCmArch(self.device().index()) &&
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
   at::native::resize_output(amax, {});
 
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
   cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM) && defined(HIPBLASLT)
   // rocm's hipblaslt does not yet support amax, so calculate separately
   auto out_float32 = out.to(kFloat);
   out_float32.abs_();
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a96075245..d6149ef77 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1282,6 +1282,9 @@ if(USE_ROCM)
     if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
       list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
     endif()
+    if(hipblaslt_FOUND)
+      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+    endif()
     if(HIPBLASLT_CUSTOM_DATA_TYPE)
       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
     endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f6ca263c5..53eb0b63c 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    find_package_and_print_version(hipblaslt REQUIRED)
+    find_package_and_print_version(hipblaslt)
  endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
 
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+  if(hipblastlt_FOUND)
     # check whether hipblaslt is using its own datatype
     set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
     file(WRITE ${file} ""
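For orientation, the pattern the patch applies can be reduced to a short sketch (not part of the patch; the helper function below is hypothetical): cmake/Dependencies.cmake adds -DHIPBLASLT to the HIP compile flags only when the hipblaslt package is found, and the C++ sources then test that define instead of comparing ROCM_VERSION, so PyTorch can still build on ROCm systems without hipBLASLt.

// Sketch only: the guard style introduced by this patch. HIPBLASLT is assumed
// to arrive via HIP_CXX_FLAGS when CMake finds the hipblaslt package.
#ifdef USE_ROCM
#ifdef HIPBLASLT
#include <hipblaslt/hipblaslt-ext.hpp>  // only reachable when hipblaslt was found at build time
#endif
#endif

// Hypothetical helper, not in the patch: reports whether the Lt-based GEMM
// paths were compiled in (falls back to rocBLAS-only builds otherwise).
inline bool built_with_hipblaslt() {
#if defined(USE_ROCM) && defined(HIPBLASLT)
  return true;
#else
  return false;
#endif
}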