#pragma once #include #include namespace at { class TORCH_API OptionalTensorRef { public: OptionalTensorRef() = default; ~OptionalTensorRef() { ref_.unsafeReleaseTensorImpl(); } OptionalTensorRef(const TensorBase& src) : ref_(Tensor::unsafe_borrow_t{}, src) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); } OptionalTensorRef(const OptionalTensorRef& rhs) : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} OptionalTensorRef& operator=(OptionalTensorRef rhs) { std::swap(ref_, rhs.ref_); return *this; } bool has_value() const { return ref_.defined(); } const Tensor& getTensorRef() const & { return ref_; } const Tensor& operator*() const & { return ref_; } const Tensor* operator->() const & { return &ref_; } operator bool() const { return ref_.defined(); } private: Tensor ref_; }; template auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t { // Return the grad argument in case of a hook with void return type to have an // std::function with Tensor return type static_assert(std::is_same::value, "Expected hook to return void"); return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { OptionalTensorRef grad(grad_base); fn(*grad); return Tensor(); }); } template auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t { return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { OptionalTensorRef grad(grad_base); Tensor ret = fn(*grad); return TensorBase(std::move(ret)); }); } } // namespace at