import itertools
import textwrap
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from typing_extensions import Literal

import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    ConstRefCType,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    Expr,
    kernel_signature,
    MutRefCType,
    NamedCType,
    NativeSignature,
    tensorT,
)
from torchgen.context import method_with_native_function, native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    DeviceCheckType,
    DispatchKey,
    gets_generated_out_inplace_wrapper,
    is_cuda_dispatch_key,
    NativeFunction,
    NativeFunctionsGroup,
    SchemaKind,
    TensorOptionsArguments,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import assert_never, mapMaybe, Target


def gen_registration_headers(
    backend_index: BackendIndex,
    per_operator_headers: bool,
    rocm: bool,
) -> List[str]:
    if per_operator_headers:
        headers = ["#include <ATen/ops/as_strided_native.h>"]
    else:
        headers = ["#include <ATen/NativeFunctions.h>"]

    if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
        headers.append("#include <ATen/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.CUDA:
        if rocm:
            headers.append("#include <ATen/hip/EmptyTensor.h>")
        else:
            headers.append("#include <ATen/cuda/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.MPS:
        headers.append("#include <ATen/mps/EmptyTensor.h>")
    elif per_operator_headers:
        headers += [
            "#include <ATen/ops/empty.h>",
            "#include <ATen/ops/empty_strided.h>",
            "#include <ATen/ops/_copy_from_and_resize.h>",
            "#include <ATen/ops/_copy_from.h>",
        ]
    else:
        headers.append("#include <ATen/Functions.h>")

    return headers


def gen_empty_impl_names(
    backend_index: BackendIndex,
) -> Tuple[Optional[str], Optional[str]]:
    empty_impl = None
    empty_strided_impl = None

    if backend_index.dispatch_key in (
        DispatchKey.Meta,
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.MPS,
    ):
        dispatch = str(backend_index.dispatch_key).lower()
        empty_impl = f"at::detail::empty_{dispatch}"
        empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
    elif backend_index.dispatch_key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
    ):
        empty_impl = "at::empty"
        empty_strided_impl = "at::empty_strided"

    return empty_impl, empty_strided_impl


def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
    if backend_index.dispatch_key == DispatchKey.Meta:
        empty_options = "options.device(at::kMeta)"
    else:
        empty_options = "options"

    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_impl is None:
        return []

    return [
        f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (strides.empty()) {{
      return {empty_impl}(sizes, {empty_options});
  }} else {{
      return {empty_strided_impl}(sizes, strides, {empty_options});
  }}
}}
"""
    ]


def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
    _, empty_strided_impl = gen_empty_impl_names(backend_index)
    return (
        []
        if empty_strided_impl is None
        else [
            f"""
c10::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (out.strides() != strides) {{
    return {empty_strided_impl}(sizes, strides, options);
  }}
  return c10::nullopt;
}}
"""
        ]
    )


def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # The function isn't used by this key (since only functional ops have a kernel for this key),
        # so we need to not include it to avoid a defined-but-not-used error.
        return []
    return [
        """
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  const bool resized = at::native::resize_output(out, sizes);
  // Only restride if a resize occurred; otherwise we ignore the (advisory)
  // strides from the meta function and directly use the output tensor's
  // preexisting strides
  if (resized) {
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      // TODO: avoid the redispatch here
      out.as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
}
"""
    ]


def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
    return [
        """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  //   1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  //   2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    ]


def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
    return [
        *gen_create_out_helper(backend_index),
        *gen_resize_out_helper(backend_index),
        *gen_check_inplace_helper(backend_index),
        *gen_maybe_create_proxy_helper(backend_index),
    ]


# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
#   - The primary function of this file is to register all of the
#     implementations for the given dispatch key to the dispatcher,
#     so they are available for use in PyTorch.  If dispatch is
#     None, we generate schema (def) registrations and catchall
#     registrations.
#   - The secondary function of this file is to generate a wrapper
#     around functions.  In CPUType these wrappers do nothing
#     (and should be removed), but in other cases they handle
#     DeviceGuard.  A small extra benefit of wrappers is they
#     are not overloaded, so they can be used in the registration
#     API without having to disambiguate which overload you want
#     (as would be the case if you directly registered native::
#     functions).
#   - The tertiary function of this file is to generate *static*
#     cpp API bindings which can be used to bypass dispatcher
#     directly to kernels, but with user-friendly cpp-style API
@dataclass(frozen=True)
class RegisterDispatchKey:
    backend_index: BackendIndex

    target: Union[
        Literal[Target.ANONYMOUS_DEFINITION],
        Literal[Target.NAMESPACED_DEFINITION],
        Literal[Target.NAMESPACED_DECLARATION],
        Literal[Target.REGISTRATION],
    ]

    # Selector object to determine which operators to generate
    # registration code for.
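    # (Only consulted when target is Target.REGISTRATION: functions for which
    # selector.is_native_function_selected(f) returns False are skipped below.)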
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether or not to generate symint registrations.  External users
    # of codegen who don't care about symints can set this to false to get
    # non-SymInt codegen
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: Optional[str]

    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
    # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
    skip_dispatcher_op_registration: bool

    @staticmethod
    def gen_device_check(
        type: DeviceCheckType, args: List[Argument], method_name: str
    ) -> str:
        if type == DeviceCheckType.NoCheck:
            return "  // No device check\n"

        device_check = "c10::optional<Device> common_device = nullopt;\n"
        device_check += "(void)common_device; // Suppress unused variable warning\n"
        for arg in args:
            # Only tensor like arguments are eligible
            if arg.type.is_tensor_like():
                device_check += f"""
  c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
        return device_check

    @method_with_native_function
    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
        if isinstance(f, NativeFunctionsGroup):
            g: NativeFunctionsGroup = f
            # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
            # gen_structured() has special logic to handle auto-generated kernels.
            if g.structured:
                return self.gen_structured(g)
            else:
                return list(
                    mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
                )
        elif isinstance(f, NativeFunction):
            r = self.gen_unstructured(f)
            return [] if r is None else [r]
        else:
            assert_never(f)

    def wrapper_kernel_sig(
        self, f: NativeFunction
    ) -> Union[NativeSignature, DispatcherSignature]:
        # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
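        # (Illustrative: for an overload such as `add.out` this yields a wrapper
        # name along the lines of `wrapper_out_add_out`; the exact spelling is
        # determined by DispatcherSignature's own naming rules.)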
        return DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_{f.func.name.overload_name}_", symint=self.symint
        )

    def gen_out_inplace_wrapper(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
    ) -> Optional[str]:
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = "at::_copy_from"
        elif k is SchemaKind.out:
            copy_op = "at::_copy_from_and_resize"
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        func_res = f"{name}_tmp"
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            updates = "\n  ".join(
                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
                for i, ret_name in enumerate(return_names)
            )
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        else:
            ret_name = return_names[0]
            updates = f"{copy_op}({func_res}, {ret_name});"
            returns = ret_name

        functional_sig = self.wrapper_kernel_sig(g.functional)
        wrapper_name = sig.name()

        return f"""\
{sig.defn(name=wrapper_name)} {{
  auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  {updates}
  return {returns};
}}
"""

    def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
        metadata = self.backend_index.get_kernel(g)
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif metadata is None or not metadata.structured:
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))

        structured_gen = StructuredRegisterDispatchKey(
            self.backend_index,
            self.target,
            self.selector,
            self.rocm,
            self.symint,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))

    def gen_unstructured(
        self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
    ) -> Optional[str]:
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # We want to generate inplace/out wrappers for ops that don't
                    # have a kernel registered for this backend.
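                    # (Such a wrapper calls the functional kernel into a
                    # temporary and copies the result back via at::_copy_from /
                    # at::_copy_from_and_resize; see gen_out_inplace_wrapper above.)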
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result
            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = "  // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
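                        # (Order: the `self` argument, then the `out` arguments,
                        # then the remaining flat positional arguments; the first
                        # tensor-like argument found is used for the guard.)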
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}
  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           STRUCTURED
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    g: NativeFunctionsGroup

    def gen_class_set_output_functions(
        self, k: SchemaKind, parent_class: str, generate_super: bool
    ) -> str:
        if generate_super:
            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""

        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
            maybe_star = "*" if k is SchemaKind.functional else ""
            return f"""
void set_output_{name}(
    int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
    TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
    if (!names.empty()) {{
      namedinference::propagate_names({maybe_star}outputs_[output_idx], names);
    }}
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
{textwrap.indent(set_output_super, "    ")}
}}
"""

        return f"""
{gen_set_output_function("strided", maybe_create_proxy=True)}
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
"""

    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
        if self.backend_index.dispatch_key in [
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.CompositeExplicitAutogradNonFunctional,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)

    # returns the definition of a ctor, as well as how to construct
    # this class to a variable named op
    def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
        if k is SchemaKind.functional:
            return ""
        elif k is SchemaKind.inplace:
            # TODO: Make sure out argument is guaranteed to be self
            return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
        elif k is SchemaKind.out:
            out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
            out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
            return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)

    def gen_class(
        self,
        f: NativeFunction,
        k: SchemaKind,
        *,
        class_name: str,
        parent_class: str,
        generate_super: bool,
    ) -> str:
        if k is SchemaKind.functional:
            output_type = "c10::ExclusivelyOwned<Tensor>"
            output_value = "*outputs_[output_idx]"
            proxy_field = ""
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"
        elif k is SchemaKind.out:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, {len(f.func.returns)}> proxy_outputs_;"

        if self.backend_index.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
            else:
                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            guard_field = "c10::OptionalDeviceGuard guard_;"
        elif self.backend_index.dispatch_key == DispatchKey.MPS:
            # TODO: Move to OptionalMPSGuard.
            guard_field = "c10::OptionalDeviceGuard guard_;"
        else:
            guard_field = ""

        indent = " " * 4
        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
        lines = (
            f"struct {class_name} final : public {parent_class} {{",
            f"{textwrap.indent(class_ctor_str, indent)}",
            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
            "    const Tensor& maybe_get_output(int64_t output_idx) override {",
            f"      return {output_value};\n",
            "    }",
            f"    std::array<{output_type}, {len(f.func.returns)}> outputs_;",
            f"{textwrap.indent(proxy_field, indent)}",
            f"{textwrap.indent(guard_field, indent)}",
            "};",
        )
        return "\n".join(line for line in lines if line)

    @method_with_native_function
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        assert not f.manual_kernel_registration

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
            and f.func.kind() is SchemaKind.out
        ):
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        # Signature of the wrapper function we'll register to the dispatcher
        kern = self.backend_index.get_kernel(f)
        sig = NativeSignature(
            f.func,
            prefix="wrapper_",
            symint=kern is not None and kern.supports_symint(),
        )

        if self.target is Target.NAMESPACED_DECLARATION:
            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += f"TORCH_API {cpp_sig.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += generate_defn(cpp_sig)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif (
                self.backend_index.dispatch_key
                is DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"

            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()
                    )
                )

            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.meta_arguments(self.g), method=False
                )
            )

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                precomputed_values = [
                    *self.g.out.precomputed.replace.values(),
                    self.g.out.precomputed.add,
                ]
                for precomputed_elems in precomputed_values:
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg, binds=arg.name),
                            )
                        )

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type

                if k is SchemaKind.out:
                    expr = f"op.maybe_get_output({i})"
                else:
                    maybe_star = "*" if k is SchemaKind.functional else ""
                    expr = f"{maybe_star}op.outputs_[{i}]"

                context.append(
                    Expr(
                        expr=expr,
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(
                            out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                        ),
                    )
                )

            # With the expanded context, do the impl call (if not a meta
            # function)
            if (
                self.backend_index.dispatch_key
                == DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding
                )
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ", ".join(
                    e.expr
                    for e in translate(context, out_sig.arguments(), method=False)
                )
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ", ".join(
                    e.expr
                    for e in translate(
                        context, structured.impl_arguments(self.g), method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Go over each output, and check if there is a proxy created for it.
            # If so, copy it over to the original output.
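            # (maybe_create_proxy only creates a proxy when the user-supplied out
            # tensor's strides don't match the requested strides, so the generated
            # copy below only runs in that case.)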
            if k is SchemaKind.out or k is SchemaKind.inplace:
                for i in range(len(f.func.returns)):
                    sig_body.append(
                        f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(**op.proxy_outputs_[{i}]);"
                    )

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
                else:
                    moved = ", ".join(
                        f"std::move(op.outputs_[{i}]).take()"
                        for i in range(len(f.func.returns))
                    )
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ", ".join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
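
# Rough shape of what this module emits into a Register<Backend>.cpp file, shown
# here for orientation only (operator and kernel names are illustrative; the real
# output is produced by the f-string templates above):
#
#   namespace {
#
#   at::Tensor wrapper_<op>(const at::Tensor & self, ...) {
#     // optional device check / device guard
#     return at::native::<kernel>(self, ...);
#   }
#
#   } // anonymous namespace
#
#   ...
#   m.impl("<op>", TORCH_FN(wrapper_<op>));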