# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from types import (
    FunctionType,
    BuiltinMethodType,
    MethodDescriptorType,
    WrapperDescriptorType,
    GetSetDescriptorType,
)

from functorch._C import dim as _C

_wrap_method = _C._wrap_method

# Attribute kinds that wrap_type re-exposes as wrapped callables.
FUNC_TYPES = (FunctionType, MethodDescriptorType, BuiltinMethodType, WrapperDescriptorType)
# Attribute kinds that wrap_type re-exposes as read-only properties built from
# the descriptor's ``__get__``.
PROPERTY_TYPES = (GetSetDescriptorType, property)

# Names wrap_type never copies onto the patched class: construction,
# representation and class-bookkeeping attributes.
_SKIPPED_NAMES = frozenset(
    {
        '__dict__',
        '__new__',
        '__init__',
        '__repr__',
        '__weakref__',
        '__doc__',
        '__module__',
        '__dir__',
    }
)

# Private sentinel distinguishing "attribute absent" from any real value.
_MISSING = object()


def _py_wrap_method(orig, __torch_function__):
    """Return a pure-Python wrapper that routes calls of ``orig`` through a hook.

    Calling the returned function with ``(*args, **kwargs)`` invokes
    ``__torch_function__(orig, None, args, kwargs)`` and returns its result.
    The ``None`` fills the second slot, presumably the ``types`` argument of
    torch's ``__torch_function__(func, types, args, kwargs)`` convention.

    Args:
        orig: The original callable being wrapped.
        __torch_function__: The hook that receives every call.

    Returns:
        A new function forwarding to ``__torch_function__``.
    """
    def impl(*args, **kwargs):
        return __torch_function__(orig, None, args, kwargs)
    return impl


def wrap_type(use_c, to_patch, pattern, __torch_function__):
    """Copy wrapped versions of ``pattern``'s methods and properties onto ``to_patch``.

    Every class in ``pattern.mro()`` except the final ``object`` is scanned.
    Bases are visited first, so a subclass definition overrides a base one.
    Each collected attribute is handled as follows:

    * Names in the fixed skip list (``__init__``, ``__repr__``, ...) are ignored.
    * Names that ``to_patch`` already resolves to something other than the
      corresponding ``object`` attribute are left alone. Attributes
      ``to_patch`` only inherits from ``object`` (e.g. ``__eq__``) are still
      patched.
    * Callables (``FUNC_TYPES``) are replaced by ``wrap_method(obj, hook)``.
    * Properties and getset descriptors (``PROPERTY_TYPES``) become a new
      getter-only ``property`` wrapping the original ``__get__``.

    Args:
        use_c: If true, wrap with the C implementation ``_C._wrap_method``;
            otherwise use the pure-Python ``_py_wrap_method``.
        to_patch: Class that receives the wrapped attributes via ``setattr``.
        pattern: Class whose MRO supplies the attributes to wrap.
        __torch_function__: Hook passed to the wrapper factory. With the
            Python path it is called as
            ``__torch_function__(orig, None, args, kwargs)``.

    Returns:
        None. ``to_patch`` is mutated in place.
    """
    wrap_method = _wrap_method if use_c else _py_wrap_method

    # Merge class dicts from the most basic class (excluding ``object``, the
    # last MRO entry) to ``pattern`` itself, so later updates win.
    attrs = {}
    for klass in reversed(pattern.mro()[:-1]):
        attrs.update(klass.__dict__)

    def wrap_attr(orig):
        # Only the getter is wrapped; the resulting property has no setter.
        return property(wrap_method(orig.__get__, __torch_function__))

    for name, obj in attrs.items():
        if name in _SKIPPED_NAMES:
            continue

        # Skip things that have been overloaded on ``to_patch``. Things that
        # come from ``object``, like ``__eq__``, still need to be patched.
        existing = getattr(to_patch, name, _MISSING)
        if existing is not _MISSING and existing is not getattr(object, name, None):
            continue

        if isinstance(obj, FUNC_TYPES):
            setattr(to_patch, name, wrap_method(obj, __torch_function__))
        elif isinstance(obj, PROPERTY_TYPES):
            setattr(to_patch, name, wrap_attr(obj))