#pragma once

#include <ATen/core/ITensorListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Optional.h>

/*
 * [Note: hacky wrapper removal for optional tensor]
 *
 * The kernel implementation takes an optional tensor marked in the schema as
 * Tensor? but the C++ function takes Tensor instead of the optional<Tensor>
 * expected by the dispatcher.
 *
 * To remove the hacky wrapper, the C++ function is changed to take
 * optional<Tensor> and unwrap the Tensor value at the beginning of
 * the function, e.g.:
 *   > c10::MaybeOwned<Tensor> weight_maybe_owned =
 *   >     at::borrow_from_optional_tensor(weight_opt);
 *   > const Tensor& weight = *weight_maybe_owned;
 *
 * We may want to make the kernel handle optional<Tensor> directly without
 * going through the creation of a default-constructed Tensor in
 * at::borrow_from_optional_tensor.
 */

/*
 * [Note: hacky wrapper removal for TensorOptions]
 *
 * The kernel implementation takes a TensorOptions argument but the dispatcher
 * expects separate arguments for dtype, layout, device, pin_memory.
 *
 * To remove the hacky wrapper, the kernel implementation is changed to take
 * the 4 arguments (dtype, layout, device, pin_memory), and assemble the
 * TensorOptions value at the beginning of the function, e.g.:
 *   > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
 *   >    .device(device).pinned_memory(pin_memory);
 *
 * We may want to make the kernel handle these parameters directly without
 * going through the creation of a TensorOptions value.
*/ namespace c10 { namespace impl { TORCH_API void common_device_check_failure(optional& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName); inline void check_and_update_common_device(optional& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { // TODO: Remove this once the following issue is addressed: // https://github.com/pytorch/pytorch/issues/57380 if (!tensor.defined()) { return; } if (!common_device.has_value()) { common_device = tensor.device(); return; } if (C10_UNLIKELY(common_device != tensor.device())) { common_device_check_failure(common_device, tensor, methodName, argName); } } inline void check_and_update_common_device(optional& common_device, const optional& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) { if (tensor.has_value()) { check_and_update_common_device(common_device, tensor.value(), methodName, argName); } } inline void check_and_update_common_device(optional& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { for (const auto& tensor : tensors) { check_and_update_common_device(common_device, tensor, methodName, argName); } } inline void check_and_update_common_device(optional& common_device, const List>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) { for (const auto& tensor : tensors) { check_and_update_common_device(common_device, tensor, methodName, argName); } } } // namespace impl } // namespace c10