"""An internal wrapper for the beartype library.

The module returns a no-op decorator when the beartype library is not installed.
"""
import functools
import traceback
import typing
import warnings
from types import ModuleType

from torch.onnx import _exporter_states, errors
from torch.onnx._globals import GLOBALS

try:
    import beartype as _beartype_lib  # type: ignore[import]
    from beartype import roar as _roar  # type: ignore[import]

    # Beartype warns when we import from typing because the types are deprecated
    # in Python 3.9. But there will be a long time until we can move to using
    # the native container types for type annotations (when 3.9 is the lowest
    # supported version). So we silence the warning.
    warnings.filterwarnings(
        "ignore",
        category=_roar.BeartypeDecorHintPep585DeprecationWarning,
    )
except ImportError:
    # beartype is an optional dependency; fall back to the no-op decorator.
    _beartype_lib = None  # type: ignore[assignment]
except Exception as e:
    # Warn errors that are not import errors (unexpected).
    warnings.warn(f"{e}")
    _beartype_lib = None  # type: ignore[assignment]


def _no_op_decorator(func):
    """Return ``func`` unchanged; used whenever runtime type checking is off."""
    return func


def _create_beartype_decorator(
    runtime_check_state: _exporter_states.RuntimeTypeCheckState,
):
    """Build the decorator matching the requested runtime type-check state.

    Args:
        runtime_check_state: The requested runtime type checking mode.

    Returns:
        ``_no_op_decorator`` when the state is ``DISABLED`` or when the beartype
        library could not be imported; ``beartype.beartype`` itself when the
        state is ``ERRORS``; otherwise a decorator that reports parameter hint
        violations as ``errors.CallHintViolationWarning`` warnings and then
        calls the undecorated function.
    """
    # beartype needs to be imported outside of the function and aliased because
    # this module overwrites the name "beartype".

    if runtime_check_state == _exporter_states.RuntimeTypeCheckState.DISABLED:
        return _no_op_decorator
    if _beartype_lib is None:
        # If the beartype library is not installed, return a no-op decorator
        if runtime_check_state == _exporter_states.RuntimeTypeCheckState.ERRORS:
            warnings.warn(
                "TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK is set to 'ERRORS', "
                "but the beartype library is not installed. "
                "Install beartype with `pip install beartype` to enable runtime type checking."
            )
        return _no_op_decorator

    assert isinstance(_beartype_lib, ModuleType)

    if runtime_check_state == _exporter_states.RuntimeTypeCheckState.ERRORS:
        # Enable runtime type checking which errors on any type hint violation.
        return _beartype_lib.beartype

    # Warnings only
    def beartype(func):
        """Warn on type hint violation."""

        if "return" in func.__annotations__:
            # Remove the return type from the func function's
            # annotations so that the beartype decorator does not complain
            # about the return type.
            return_type = func.__annotations__["return"]
            del func.__annotations__["return"]
            try:
                beartyped = _beartype_lib.beartype(func)
            finally:
                # Restore the return type to the func function's annotations,
                # even when decoration raises, so that the caller's function
                # is never left with a missing return annotation.
                func.__annotations__["return"] = return_type
        else:
            beartyped = _beartype_lib.beartype(func)

        @functools.wraps(func)
        def _coerce_beartype_exceptions_to_warnings(*args, **kwargs):
            try:
                return beartyped(*args, **kwargs)
            except _roar.BeartypeCallHintParamViolation:
                # Fall back to the original function if the beartype hint is violated.
                warnings.warn(
                    traceback.format_exc(),
                    category=errors.CallHintViolationWarning,
                    stacklevel=2,
                )

            return func(*args, **kwargs)  # noqa: B012

        return _coerce_beartype_exceptions_to_warnings

    return beartype


if typing.TYPE_CHECKING:
    # This is a hack to make mypy play nicely with the beartype decorator.
    def beartype(func):
        return func

else:
    beartype = _create_beartype_decorator(GLOBALS.runtime_type_check_state)
    # Make sure that the beartype decorator is enabled whichever path we took.
    assert beartype is not None