#pragma once #include #include #include // TODO: No need to have this whole header, we can just put it all in // the cpp file namespace at { namespace cuda { namespace detail { // Set the callback to initialize Magma, which is set by // torch_cuda_cu. This indirection is required so magma_init is called // in the same library where Magma will be used. TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); TORCH_CUDA_CPP_API bool hasPrimaryContext(int64_t device_index); TORCH_CUDA_CPP_API c10::optional getDeviceIndexWithPrimaryContext(); // The real implementation of CUDAHooksInterface struct CUDAHooks : public at::CUDAHooksInterface { CUDAHooks(at::CUDAHooksArgs) {} void initCUDA() const override; Device getDeviceFromPtr(void* data) const override; bool isPinnedPtr(void* data) const override; const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; bool hasCUDA() const override; bool hasMAGMA() const override; bool hasCuDNN() const override; bool hasCuSOLVER() const override; bool hasROCM() const override; const at::cuda::NVRTC& nvrtc() const override; int64_t current_device() const override; bool hasPrimaryContext(int64_t device_index) const override; Allocator* getCUDADeviceAllocator() const override; Allocator* getPinnedMemoryAllocator() const override; bool compiledWithCuDNN() const override; bool compiledWithMIOpen() const override; bool supportsDilatedConvolutionWithCuDNN() const override; bool supportsDepthwiseConvolutionWithCuDNN() const override; bool supportsBFloat16ConvolutionWithCuDNNv8() const override; bool hasCUDART() const override; long versionCUDART() const override; long versionCuDNN() const override; std::string showConfig() const override; double batchnormMinEpsilonCuDNN() const override; int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override; void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override; int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override; void cuFFTClearPlanCache(int64_t device_index) const override; int getNumGPUs() const override; void deviceSynchronize(int64_t device_index) const override; }; }}} // at::cuda::detail