home:birdwatcher:machinelearning
File pytorch-rocm-do-not-use-aotriton-if-not-required.patch of Package python-torch-stable
--- a/CMakeLists.txt 2024-05-29 17:15:01.000000000 +0200
+++ b/CMakeLists.txt 2024-08-26 22:03:34.930907505 +0200
@@ -771,7 +771,13 @@
     USE_MEM_EFF_ATTENTION
     "Enable memory-efficient attention for scaled dot product attention.\
 Will be disabled if not supported by the platform" ON
-    "USE_CUDA" OFF)
+    "USE_CUDA OR USE_ROCM" OFF)
+
+if(USE_ROCM)
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+    include(cmake/External/aotriton.cmake)
+  endif()
+endif()
 
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
--- a/cmake/Dependencies.cmake 2024-05-29 17:15:01.000000000 +0200
+++ b/cmake/Dependencies.cmake 2024-08-26 22:02:16.051084133 +0200
@@ -1334,11 +1334,6 @@
   else()
     message(STATUS "Disabling Kernel Assert for ROCm")
   endif()
-
-  include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
-  if(USE_CUDA)
-    caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
-  endif()
 else()
   caffe2_update_option(USE_ROCM OFF)
 endif()
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp 2024-08-25 14:27:15.042669638 +0200
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp 2024-08-26 21:41:25.003719366 +0200
@@ -22,7 +22,10 @@
 #include <functional>
 
 #if USE_ROCM
+#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
 #include <aotriton/flash.h>
+#define USE_AOTRITON 1
+#endif
 #endif
 
 /**
@@ -187,6 +190,7 @@
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
+#if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -197,6 +201,9 @@
       return false;
   }
 #else
+  return false;
+#endif
+#else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm80, sm90>(dprops)) {
     if (debug) {
@@ -217,6 +224,21 @@
   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
+#if USE_ROCM
+#if USE_AOTRITON
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+      auto dprops = at::cuda::getCurrentDeviceProperties();
+      if (debug) {
+          TORCH_WARN(
+              "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
+      }
+      return false;
+  }
+#else
+  return false;
+#endif
+#else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm50, sm90>(dprops)) {
     if (debug) {
@@ -230,6 +252,7 @@
     return false;
   }
   return true;
+#endif
 }
 
 bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
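
Note: the following is a minimal, self-contained sketch of the preprocessor gating pattern the patch introduces in sdp_utils.cpp, not part of the patch itself. check_gpu_supported() is a hypothetical stand-in for the real aotriton::v2::flash::check_gpu(stream) call, and the CUDA fallback branch is elided; only the USE_ROCM / USE_AOTRITON branching mirrors the patched code.

// Sketch: when AOTriton was compiled in (USE_AOTRITON defined because
// <aotriton/flash.h> was included), the ROCm branch asks it whether the
// current GPU is supported; otherwise the check simply fails instead of
// referencing AOTriton symbols that were never built.

#include <cstdio>

// Hypothetical stand-in for aotriton::v2::flash::check_gpu(stream).
static bool check_gpu_supported() { return true; }

bool can_use_flash_attention_sketch() {
#if defined(USE_ROCM)
#if defined(USE_AOTRITON)
  // ROCm build with AOTriton: defer to its GPU-architecture check.
  return check_gpu_supported();
#else
  // ROCm build without AOTriton: the attention kernels do not exist.
  return false;
#endif
#else
  // CUDA/CPU build: the usual SM-version checks would run here (elided).
  return true;
#endif
}

int main() {
  std::printf("flash attention usable: %d\n", can_use_flash_attention_sketch());
  return 0;
}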