import inspect
from functools import wraps
from typing import Any, Callable, Optional, Type, Union, get_type_hints

from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
from torch.utils.data.datapipes._typing import _DataPipeMeta


######################################################
# Functional API
######################################################
class functional_datapipe(object):
    """Decorator that registers a DataPipe under a functional name.

    The decorated object is handed to
    ``IterDataPipe.register_datapipe_as_function`` (for ``IterDataPipe``
    subclasses and for objects produced by ``non_deterministic``) or to
    ``MapDataPipe.register_datapipe_as_function`` (for ``MapDataPipe``
    subclasses), and is then returned unchanged.
    """
    name: str

    def __init__(self, name: str, enable_df_api_tracing=False) -> None:
        """
        Args:
            name - the functional name passed on to ``register_datapipe_as_function``.
            enable_df_api_tracing - if set, any returned DataPipe would accept
            DataFrames API in tracing mode.
        """
        self.name = name
        self.enable_df_api_tracing = enable_df_api_tracing

    def __call__(self, cls):
        """Register ``cls`` under ``self.name`` and return it unchanged.

        Args:
            cls: an ``IterDataPipe`` subclass, a ``MapDataPipe`` subclass, a
                ``non_deterministic`` instance, or a bound method of a
                ``non_deterministic`` instance (what its functional form returns).

        Raises:
            TypeError: if ``cls`` is not a class and is not produced by
                ``non_deterministic``, or is an ``IterDataPipe`` subclass whose
                metaclass is not ``_DataPipeMeta``.
        """
        # The class check must come first: `issubclass` raises TypeError for
        # non-class objects, which would make the `non_deterministic` path
        # below unreachable.
        if isinstance(cls, Type):  # type: ignore[arg-type]
            if issubclass(cls, IterDataPipe):
                if not isinstance(cls, _DataPipeMeta):
                    raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
                IterDataPipe.register_datapipe_as_function(
                    self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
            elif issubclass(cls, MapDataPipe):
                MapDataPipe.register_datapipe_as_function(self.name, cls)
            # Any other class is returned without being registered.
            return cls

        # with non_deterministic decorator: either the decorator instance
        # itself or its bound `deterministic_wrapper_fn`.
        if not isinstance(cls, non_deterministic) and \
                not isinstance(getattr(cls, '__self__', None), non_deterministic):
            raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
        IterDataPipe.register_datapipe_as_function(
            self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
        return cls


######################################################
# Determinism
######################################################
# Module-wide flag; when True, `non_deterministic` DataPipes refuse to be constructed.
_determinism: bool = False


class guaranteed_datapipes_determinism(object):
    """Context manager that turns the module-level ``_determinism`` flag on.

    Note that the flag is switched on when the object is constructed (not in
    ``__enter__``); ``__exit__`` restores the value seen at construction time.
    """
    prev: bool

    def __init__(self) -> None:
        global _determinism
        self.prev = _determinism
        _determinism = True

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _determinism
        _determinism = self.prev


class non_deterministic(object):
    """Decorator marking an ``IterDataPipe`` class as (possibly) non-deterministic.

    Two forms are supported:

    1. ``@non_deterministic`` applied directly to an ``IterDataPipe`` subclass:
       constructing the DataPipe raises ``TypeError`` while ``_determinism``
       is set.
    2. ``@non_deterministic(fn)`` where ``fn`` receives the constructor
       arguments and returns a ``bool``: construction raises ``TypeError``
       only when ``_determinism`` is set and ``fn`` returns ``True``.
    """
    cls: Optional[Type[IterDataPipe]] = None
    # TODO: Lambda for picking
    deterministic_fn: Callable[[], bool]

    def __init__(self, arg: Union[Type[IterDataPipe], Callable[[], bool]]) -> None:
        """
        Args:
            arg: the ``IterDataPipe`` subclass to decorate (form 1) or the
                determinism-checking function (form 2).

        Raises:
            TypeError: if ``arg`` is a class that is not an ``IterDataPipe``
                subclass, or is neither a class nor a callable.
        """
        # 1. Decorator doesn't have any argument
        if isinstance(arg, Type):  # type: ignore[arg-type]
            if not issubclass(arg, IterDataPipe):  # type: ignore[arg-type]
                raise TypeError("Only `IterDataPipe` can be decorated with `non_deterministic`"
                                ", but {} is found".format(arg.__name__))
            self.cls = arg  # type: ignore[assignment]
        # 2. Decorator has an argument of a function
        #    This class should behave differently given different inputs. Use this
        #    function to verify the determinism for each instance.
        #    When the function returns True, the instance is non-deterministic. Otherwise,
        #    the instance is a deterministic DataPipe.
        elif isinstance(arg, Callable):  # type:ignore[arg-type]
            self.deterministic_fn = arg  # type: ignore[assignment, misc]
        else:
            raise TypeError("{} can not be decorated by non_deterministic".format(arg))

    def __call__(self, *args, **kwargs):
        """Construct the decorated DataPipe (form 1) or bind the class (form 2).

        Returns:
            Form 1: a new instance of the decorated class built from
            ``*args, **kwargs``. Form 2 (no class bound yet): the bound method
            ``deterministic_wrapper_fn``, which acts as the constructor.

        Raises:
            TypeError: in form 1 when ``_determinism`` is set; in form 2 when
                the first positional argument is missing or is not an
                ``IterDataPipe`` subclass.
        """
        global _determinism
        # Decorate IterDataPipe
        if self.cls is not None:
            if _determinism:
                raise TypeError("{} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
                                "You can turn off determinism for this DataPipe if that is acceptable "
                                "for your application".format(self.cls.__name__))
            return self.cls(*args, **kwargs)  # type: ignore[call-arg]

        # Decorate with a functional argument. A missing argument or an object
        # without `__name__` must still surface as the documented TypeError
        # rather than an IndexError/AttributeError.
        target = args[0] if args else None
        if not (isinstance(target, Type) and  # type: ignore[arg-type]
                issubclass(target, IterDataPipe)):
            raise TypeError("Only `IterDataPipe` can be decorated, but {} is found"
                            .format(getattr(target, '__name__', target)))
        self.cls = target
        return self.deterministic_wrapper_fn

    def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
        """Constructor used in form 2: consult ``deterministic_fn`` first.

        Raises:
            TypeError: if ``deterministic_fn`` does not return a ``bool``, or
                if it returns ``True`` while ``_determinism`` is set.
        """
        res = self.deterministic_fn(*args, **kwargs)  # type: ignore[call-arg, misc]
        if not isinstance(res, bool):
            raise TypeError("deterministic_fn of `non_deterministic` decorator is required "
                            "to return a boolean value, but {} is found".format(type(res)))
        global _determinism
        if _determinism and res:
            raise TypeError("{} is non-deterministic with the inputs, but you set "
                            "'guaranteed_datapipes_determinism'. You can turn off determinism "
                            "for this DataPipe if that is acceptable for your application"
                            .format(self.cls.__name__))  # type: ignore[union-attr]
        return self.cls(*args, **kwargs)  # type: ignore[call-arg, misc]


######################################################
# Type validation
######################################################
# Validate each argument of DataPipe with hint as a subtype of the hint.
def argument_validation(f):
    """Wrap ``f`` so that DataPipe-hinted arguments are checked on each call.

    Only arguments whose type hint is an instance of ``_DataPipeMeta`` are
    inspected. Such an argument must be an ``IterDataPipe`` whose ``type`` is
    a subtype of the hint's ``type``; otherwise ``TypeError`` is raised
    before ``f`` runs.
    """
    sig = inspect.signature(f)
    type_hints = get_type_hints(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        call_args = sig.bind(*args, **kwargs).arguments
        for name, candidate in call_args.items():
            # Arguments without a hint are never validated.
            if name not in type_hints:
                continue
            expected = type_hints[name]
            # Hints that are not DataPipe types are ignored as well.
            if not isinstance(expected, _DataPipeMeta):
                continue
            if not isinstance(candidate, IterDataPipe):
                raise TypeError("Expected argument '{}' as a IterDataPipe, but found {}"
                                .format(name, type(candidate)))
            if not candidate.type.issubtype(expected.type):
                raise TypeError("Expected type of argument '{}' as a subtype of "
                                "hint {}, but found {}"
                                .format(name, expected.type, candidate.type))
        return f(*args, **kwargs)

    return wrapper


# Runtime validation is switched on unless explicitly disabled.
_runtime_validation_enabled: bool = True


class runtime_validation_disabled(object):
    """Context manager that switches ``_runtime_validation_enabled`` off.

    The flag is cleared when the object is constructed (not in ``__enter__``);
    ``__exit__`` puts back the value that was observed at construction time.
    """
    prev: bool

    def __init__(self) -> None:
        global _runtime_validation_enabled
        self.prev, _runtime_validation_enabled = _runtime_validation_enabled, False

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _runtime_validation_enabled
        _runtime_validation_enabled = self.prev


# Runtime checking
# Validate output data is subtype of return hint
def runtime_validation(f):
    """Wrap an ``__iter__`` method so every yielded item is checked.

    While ``_runtime_validation_enabled`` is true, each item is passed to
    ``self.type.issubtype_of_instance`` and a ``RuntimeError`` is raised for
    the first item that fails. When the flag is false, items are forwarded
    untouched.

    Raises:
        TypeError: if ``f`` is not named ``__iter__``.
    """
    # TODO:
    # Can be extended to validate '__getitem__' and nonblocking
    method_name = f.__name__
    if method_name != '__iter__':
        raise TypeError("Can not decorate function {} with 'runtime_validation'"
                        .format(method_name))

    @wraps(f)
    def wrapper(self):
        # The flag is read when iteration starts, not when decorating.
        if not _runtime_validation_enabled:
            yield from f(self)
            return

        for item in f(self):
            if not self.type.issubtype_of_instance(item):
                raise RuntimeError("Expected an instance as subtype of {}, but found {}({})"
                                   .format(self.type, item, type(item)))
            yield item

    return wrapper