#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { struct Function; struct CompilationUnit; } // namespace jit TORCH_API bool isCustomClass(const c10::IValue& v); } // namespace torch namespace c10 { struct IValue; struct ClassType; struct TupleType; struct EnumType; struct InferredType; // For custom class __init__ registration, we need to pass in a function // that looks like this: [](IValue x, args...) // However, make_boxed_from_unboxed_functor.h automatically sets the input types // of the function by introspecting the types of the functor (which is IValue in // this case). However, we need the type it binds to be Foo. // Instead, we pass in a lambda [](ivalue_holder x, args...) from // which getTypePtr can recover the original class pointer. template struct tagged_capsule { IValue ivalue; }; template c10::intrusive_ptr IValue::moveToIntrusivePtr() { auto t = c10::intrusive_ptr::reclaim( payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() ? NullType::singleton() : static_cast(payload.u.as_intrusive_ptr)); clearToNone(); return t; } template c10::intrusive_ptr IValue::toIntrusivePtr() const { if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { return c10::intrusive_ptr(); } c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); return c10::intrusive_ptr::reclaim( static_cast(payload.u.as_intrusive_ptr)); } template intrusive_ptr static_intrusive_pointer_cast(intrusive_ptr r) { return intrusive_ptr::reclaim(static_cast(r.release())); } template intrusive_ptr dynamic_intrusive_pointer_cast(intrusive_ptr r) { return intrusive_ptr::reclaim(dynamic_cast(r.release())); } inline c10::intrusive_ptr IValue::toFuture() && { AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toFuture() const& { AT_ASSERT(isFuture(), "Expected Future but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue::toRRef() && { AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toRRef() const& { AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue::toQuantizer() && { AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toQuantizer() const& { AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue::toString() && { AT_ASSERT(isString(), "Expected String but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toString() const& { AT_ASSERT(isString(), "Expected String but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue::toObject() && { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toObject() const& { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue:: toPyObjectHolder() && { TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toPyObjectHolder() const& { TORCH_INTERNAL_ASSERT(isPyObject(), 
"Expected PyObject but got ", tagKind()); return toIntrusivePtr(); } inline c10::intrusive_ptr IValue::toEnumHolder() && { TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toEnumHolder() const& { TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind()); return toIntrusivePtr(); } inline c10::complex IValue::toComplexDouble() const { TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind()); auto ptr = toIntrusivePtr(); return (*ptr).val; } inline at::Tensor IValue::toTensor() && { if (C10_UNLIKELY(!isTensor())) { reportToTensorTypeError(); } auto result = std::move(payload.as_tensor); // As far as I can tell, omitting the usual explicit destructor call // is not UB in and of itself, and it's a slight perf win. The // destructor is a no-op, because the moved-from Tensor is // effectively an intrusive_ptr in the null state, so we don't need // the behavior for correctness reasons either. Leaving this // explanatory comment, including commented-out destructor call, to // make this abundantly clear. // // payload.as_tensor.~Tensor(); clearToNone(); return result; } inline at::Tensor& IValue::toTensor() & { if (C10_UNLIKELY(!isTensor())) { reportToTensorTypeError(); } return payload.as_tensor; } inline const at::Tensor& IValue::toTensor() const& { if (C10_UNLIKELY(!isTensor())) { reportToTensorTypeError(); } return payload.as_tensor; } inline c10::Storage IValue::toStorage() && { AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); return c10::Storage( moveToIntrusivePtr()); } inline c10::Storage IValue::toStorage() const& { AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); return c10::Storage(toIntrusivePtr()); } inline c10::Stream IValue::toStream() && { return c10::Stream::unpack(payload.u.as_int); } inline c10::Stream IValue::toStream() const& { return c10::Stream::unpack(payload.u.as_int); } inline c10::intrusive_ptr IValue::toBlob() && { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toBlob() const& { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); return toIntrusivePtr(); ; } inline c10::intrusive_ptr IValue::toCapsule() && { TORCH_INTERNAL_ASSERT(isCapsule()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toCapsule() const& { TORCH_INTERNAL_ASSERT(isCapsule()); return toIntrusivePtr(); } inline at::Generator IValue::toGenerator() && { AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); return at::Generator(moveToIntrusivePtr()); } inline at::Generator IValue::toGenerator() const& { AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind()); return at::Generator(toIntrusivePtr()); } inline c10::SymInt IValue::toSymInt() const { AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); if (isSymInt()) { return c10::SymInt::toSymInt(toIntrusivePtr()); } else { return c10::SymInt(payload.u.as_int); } } inline c10::SymFloat IValue::toSymFloat() const { AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); if (isSymFloat()) { return c10::SymFloat::toSymFloat(toIntrusivePtr()); } else { return c10::SymFloat(payload.u.as_double); } } namespace ivalue { void TORCH_API checkCustomClassType(const ClassType* expected_type, const Type* actual_type); template using Shared = c10::intrusive_ptr; // string struct TORCH_API ConstantString final : c10::intrusive_ptr_target 
{ private: const std::string str_; public: ConstantString(std::string str) : str_(std::move(str)) {} ConstantString(c10::string_view str) : str_(std::string(str)) {} static c10::intrusive_ptr create(std::string str_); static c10::intrusive_ptr create(c10::string_view str_); static c10::intrusive_ptr create(const char* str_); const std::string& string() const { return str_; } c10::string_view string_view() const { return str_; } operator const std::string&() const { return string(); } TORCH_API friend std::ostream& operator<<( std::ostream& out, const ConstantString& v); }; struct Future; struct TORCH_API TupleElements { private: size_t inlineSize_; // We represent TupleElements this way to save doing a heap // allocation in the common (at least for unpickling) case where we // have only 3 elements. We have our own union instead of // c10::SmallVector because c10::SmallVector always // stores the begin/end/capacity pointers, which would be a waste of // space in our use case. union { std::vector elementsVector_; // Don't want to declare a std::array because the convenient // iteration and size members are a footgun in this case -- the // actual size of the array may be smaller than 3! // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) IValue elementsInline_[3]; }; void destroyInline() { for (const auto ii : c10::irange(inlineSize_)) { elementsInline_[ii].~IValue(); } } public: using iterator = IValue*; using const_iterator = const IValue*; TupleElements() : inlineSize_(0) { new (&elementsVector_) std::vector(); } explicit TupleElements(std::vector elements) : inlineSize_(0), elementsVector_(std::move(elements)) {} explicit TupleElements(c10::ArrayRef elements) : inlineSize_(elements.size() <= 3 ? elements.size() : 0) { switch (inlineSize_) { case 3: new (&elementsInline_[2]) IValue(elements[2]); C10_FALLTHROUGH; case 2: new (&elementsInline_[1]) IValue(elements[1]); C10_FALLTHROUGH; case 1: new (&elementsInline_[0]) IValue(elements[0]); break; case 0: new (&elementsVector_) std::vector(elements.begin(), elements.end()); break; } } explicit TupleElements(IValue&& e1) : inlineSize_(1) { new (&elementsInline_[0]) IValue(std::move(e1)); } explicit TupleElements(IValue&& e1, IValue&& e2) : inlineSize_(2) { new (&elementsInline_[0]) IValue(std::move(e1)); new (&elementsInline_[1]) IValue(std::move(e2)); } explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3) : inlineSize_(3) { new (&elementsInline_[0]) IValue(std::move(e1)); new (&elementsInline_[1]) IValue(std::move(e2)); new (&elementsInline_[2]) IValue(std::move(e3)); } ~TupleElements() { if (inlineSize_) { destroyInline(); } else { elementsVector_.~vector(); } } // It would be nice to make this noncopyable to prevent people from // writing code like `auto output = // forward(...).toTupleRef().elements()` (which does refcount bumps on // each element, unlike the more efficient but verbose // ``` // auto outputIntrusivePtr = forward(...).toTuple(); // const auto& output = outputIntrusivePtr->elements(); // ``` // ), but there is simply an overwhelming amount of code that does // it the inefficient way. // See also operator std::vector below. 
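// A hedged usage sketch of the inline-storage scheme described above
// (variable names are illustrative only): constructing from at most three
// IValues keeps the elements in elementsInline_ with no heap allocation,
// while the std::vector-based constructor always uses the heap-backed
// elementsVector_:
//
//   ivalue::TupleElements two(IValue(1), IValue(2));  // inlineSize_ == 2, stored inline
//   std::vector<IValue> vals{IValue(1), IValue(2), IValue(3), IValue(4)};
//   ivalue::TupleElements big(std::move(vals));       // inlineSize_ == 0, heap-backed vector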
TupleElements(const TupleElements& rhs) : inlineSize_(rhs.inlineSize_) { if (rhs.inlineSize_) { for (const auto ii : c10::irange(inlineSize_)) { new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); } } else { new (&elementsVector_) std::vector(rhs.elementsVector_); } } TupleElements& operator=(const TupleElements& rhs) { if (inlineSize_) { if (rhs.inlineSize_) { for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { elementsInline_[ii] = rhs.elementsInline_[ii]; } if (rhs.inlineSize_ > inlineSize_) { for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); } } else { for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { elementsInline_[ii].~IValue(); } } } else { destroyInline(); new (&elementsVector_) std::vector(rhs.elementsVector_); } } else { if (rhs.inlineSize_) { elementsVector_.~vector(); for (const auto ii : c10::irange(rhs.inlineSize_)) { new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); } } else { elementsVector_ = rhs.elementsVector_; } } inlineSize_ = rhs.inlineSize_; return *this; } TupleElements(TupleElements&& rhs) noexcept : inlineSize_(rhs.inlineSize_) { if (inlineSize_) { for (const auto ii : c10::irange(inlineSize_)) { new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); } } else { new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); } } TupleElements& operator=(TupleElements&& rhs) noexcept { if (inlineSize_) { if (rhs.inlineSize_) { for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { elementsInline_[ii] = std::move(rhs.elementsInline_[ii]); } if (rhs.inlineSize_ > inlineSize_) { for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); } } else { for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { elementsInline_[ii].~IValue(); } } } else { destroyInline(); new (&elementsVector_) std::vector(std::move(rhs.elementsVector_)); } } else { if (rhs.inlineSize_) { elementsVector_.~vector(); for (const auto ii : c10::irange(rhs.inlineSize_)) { new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); } } else { elementsVector_ = std::move(rhs.elementsVector_); } } inlineSize_ = rhs.inlineSize_; return *this; } C10_NODISCARD c10::ArrayRef asArrayRef() const { if (inlineSize_) { return c10::ArrayRef(elementsInline_, inlineSize_); } else { return elementsVector_; } } // Mimic implicit conversion from std::vector to ArrayRef. operator c10::ArrayRef() const { return asArrayRef(); } static size_t hash(const TupleElements& v) { return c10::hash>()(v.asArrayRef()); } void setContents(std::vector&& contents) { if (inlineSize_) { destroyInline(); new (&elementsVector_) std::vector(std::move(contents)); inlineSize_ = 0; } else { elementsVector_ = std::move(contents); } } C10_NODISCARD bool empty() const { return inlineSize_ ? false : elementsVector_.empty(); } C10_NODISCARD size_t size() const { return inlineSize_ ? 
inlineSize_ : elementsVector_.size(); } C10_NODISCARD IValue& operator[](size_t idx) { if (inlineSize_) { return elementsInline_[idx]; } else { return elementsVector_[idx]; } } C10_NODISCARD const IValue& operator[](size_t idx) const { if (inlineSize_) { return elementsInline_[idx]; } else { return elementsVector_[idx]; } } C10_NODISCARD IValue& at(size_t idx) { if (inlineSize_) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); return elementsInline_[idx]; } else { return elementsVector_.at(idx); } } C10_NODISCARD const IValue& at(size_t idx) const { if (inlineSize_) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); return elementsInline_[idx]; } else { TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size()); return elementsVector_.at(idx); } } C10_NODISCARD iterator begin() { if (inlineSize_) { return elementsInline_; } else { return elementsVector_.data(); } } C10_NODISCARD iterator end() { if (inlineSize_) { return elementsInline_ + inlineSize_; } else { return elementsVector_.data() + elementsVector_.size(); } } C10_NODISCARD const_iterator begin() const { if (inlineSize_) { return elementsInline_; } else { return elementsVector_.data(); } } C10_NODISCARD const_iterator end() const { if (inlineSize_) { return elementsInline_ + inlineSize_; } else { return elementsVector_.data() + elementsVector_.size(); } } C10_NODISCARD const_iterator cbegin() const { return begin(); } C10_NODISCARD const_iterator cend() const { return end(); } C10_NODISCARD std::vector vec() const & { return asArrayRef().vec(); } C10_NODISCARD IValue& back() { return *(end() - 1); } C10_NODISCARD const IValue& back() const { return *(end() - 1); } C10_NODISCARD std::vector vec() && { std::vector result; result.reserve(size()); for (auto&& iv : *this) { result.push_back(std::move(iv)); } return result; } // More compatibility shims for the overwhelming amount of code that // likes to copy tuple elements into a vector; see comment above the // copy constructor. 
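// For example (a hedged sketch reusing the forward(...) pattern from the
// comment above the copy constructor): both lines below funnel through the
// conversion operators that follow and materialize a fresh std::vector<IValue>:
//
//   std::vector<IValue> copied = forward(...).toTupleRef().elements(); // const& conversion: copies each element
//   std::vector<IValue> moved  = std::move(myElements);                // && conversion: moves each element out
//
// where myElements is a TupleElements held by value (an illustrative name).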
operator std::vector() const & { return vec(); } operator std::vector() && { return vec(); } }; template struct TupleTypeFactory {}; template <> struct TORCH_API TupleTypeFactory { static TupleTypePtr create(std::vector types) { return TupleType::create(std::move(types)); } static TupleTypePtr fallback(const Type& type); }; template <> struct TORCH_API TupleTypeFactory { static DynamicTypePtr create(std::vector elemTypes); static DynamicTypePtr fallback(const Type&); }; struct TORCH_API Tuple : c10::intrusive_ptr_target { private: TupleElements elements_; mutable c10::TypePtr type_; // lazily computed for unnamed tuples public: // named tuples have additional type information, so we // directly create them tagged static c10::intrusive_ptr createNamed( std::vector elements_, c10::TypePtr type_) { return c10::make_intrusive(std::move(elements_), std::move(type_)); } static c10::intrusive_ptr createNamed( TupleElements elements_, std::shared_ptr type_) { return c10::make_intrusive(std::move(elements_), std::move(type_)); } static c10::intrusive_ptr createNamed( std::initializer_list elements_, std::shared_ptr type_) { return createNamed(TupleElements(c10::ArrayRef(elements_)), std::move(type_)); } // MSVC apparently can't disambiguate the other two overloads of // create when passed an initializer_list without this. static c10::intrusive_ptr create(std::initializer_list elements_) { return create(c10::ArrayRef(elements_)); } static c10::intrusive_ptr create(std::vector elements_) { return c10::make_intrusive(std::move(elements_)); } static c10::intrusive_ptr create(TupleElements elements_) { return c10::make_intrusive(std::move(elements_)); } static c10::intrusive_ptr create(c10::ArrayRef elements_) { return create(TupleElements(elements_)); } static c10::intrusive_ptr create(IValue e1) { return c10::make_intrusive(std::move(e1)); } static c10::intrusive_ptr create(IValue e1, IValue e2) { return c10::make_intrusive(std::move(e1), std::move(e2)); } static c10::intrusive_ptr create(IValue e1, IValue e2, IValue e3) { return c10::make_intrusive(std::move(e1), std::move(e2), std::move(e3)); } private: // Workaround inability to use `>` operator in template argument list. template static constexpr bool hasMoreThanThreeArgs() { return sizeof...(Args) > 3; } public: template static c10::intrusive_ptr create(Args&&... elements_) { switch (sizeof...(Args)) { case 1: case 2: case 3: return create(IValue(std::forward(elements_))...); default: return create( std::vector{IValue(std::forward(elements_))...}); } } // Again, it would be nice to make this noncopyable, but there's a // lot of extant code that copies Tuples. 
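// A hedged usage sketch of the factory functions above (names illustrative):
//
//   auto tup = ivalue::Tuple::create(IValue(1), IValue(2.5), IValue(true));
//   const auto& elems = tup->elements();  // borrows; no per-element refcount bumps
//   IValue first = elems[0];
//
// Holding the intrusive_ptr returned by create() and borrowing elements() is
// the cheap pattern; copying the Tuple, or its elements into a std::vector,
// is what the note above discourages.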
// Tuple(const Tuple& rhs) = delete; const TupleElements& elements() const& { return elements_; } TupleElements elements() && { return std::move(elements_); } void setElements(std::vector&& elements) { elements_.setContents(std::move(elements)); } void setElements(TupleElements&& elements) { elements_ = std::move(elements); } void unsafeSetElement(size_t idx, const IValue& element) { elements_[idx] = element; } void unsafeSetElement(size_t idx, IValue&& element) { elements_[idx] = std::move(element); } size_t size() const { return elements_.size(); } template std::shared_ptr type() const { if (!type_) { type_ = TupleTypeFactory::create(fmap(elements(), [&](const IValue& v) { return v.type(); })); } if (auto t = type_->cast()) { return t; } return TupleTypeFactory::fallback(*type_); } static size_t hash(const Tuple& t) { return c10::get_hash(t.elements()); } TORCH_API friend bool operator==( const ivalue::Tuple& lhs, const ivalue::Tuple& rhs); private: // NOTE: If we try to avoid the overloads without // `std::shared_ptr type` by defaulting it to nullptr, we // end up having to call (part of) the shared_ptr destructor for // `type` even though we should know statically it won't do // anything. explicit Tuple(std::vector elements) : elements_(std::move(elements)){} explicit Tuple(std::vector elements, c10::TypePtr type) : elements_(std::move(elements)), type_(std::move(type)) {} explicit Tuple(TupleElements&& elements) : elements_(std::move(elements)) {} explicit Tuple(TupleElements&& elements, std::shared_ptr type) : elements_(std::move(elements)), type_(std::move(type)) {} explicit Tuple(IValue&& e1) : elements_(std::move(e1)) {} explicit Tuple(IValue&& e1, std::shared_ptr type) : elements_(std::move(e1)), type_(std::move(type)) {} explicit Tuple(IValue&& e1, IValue&& e2) : elements_(std::move(e1), std::move(e2)) {} explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr type) : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {} explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3) : elements_(std::move(e1), std::move(e2), std::move(e3)) {} explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr type) : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {} friend class c10::intrusive_ptr; }; struct Object; struct PyObjectHolder; struct EnumHolder; } // namespace ivalue // Future struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { private: // Keep this private in order to force users to go through make_intrusive and // thus prevent creating a Future that's not held by an intrusive_ptr. explicit Future(TypePtr type, std::vector devices={}) : type_(std::move(type)), impl_(getTypeOfDevices(devices)), devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {} friend c10::intrusive_ptr; public: Future(const Future&) = delete; Future(Future&&) = delete; Future& operator=(const Future&) = delete; Future& operator=(Future&&) = delete; struct TORCH_API FutureError final : public std::exception { explicit FutureError(std::string&& error_msg_) : error_msg(std::move(error_msg_)) {} FutureError() = default; const char* what() const noexcept override { return error_msg.c_str(); } std::string error_msg; }; /** * Wait on the future until it completes. */ void wait() { std::unique_lock lock(mutex_); finished_cv_.wait(lock, [&]() -> bool { return completed_; }); synchronizeWithCurrentStreams(); } /** * Wait on the future until it completes and throw an * exception if an error exists. 
*/ void waitAndThrow() { wait(); if (eptr_) { std::rethrow_exception(eptr_); } } /** * Explicitly mark the future as completed with the output value. Optionally, * the storages for all tensors in IValue can be passed as well. The DataPtrs * of these storages are used to synchronize CUDA streams. If storages isn't * given we will attempt to extract it from the value, if we need to (this * happens if a non-empty set of devices was given to the constructor). Thus * one only needs to provide storages when 1) they cannot be extracted through * IValue::getSubValues() or through pickling in case of Python object; or * when 2) customized storage extraction is more efficient. */ using WeakStorage = c10::weak_intrusive_ptr; void markCompleted( IValue value, c10::optional> storages = c10::nullopt) { // Start by performing all steps that can throw, before setting any field. // Do this before even acquiring the mutex, because extractStorages might // acquire the GIL, which could lead to a lock inversion with our mutex. // See https://github.com/pytorch/pytorch/issues/58239. std::vector actualStorages; std::vector usedDevices; try { // FIXME We should always extract DataPtrs, in order to catch the case of // users using CUDA values but forgetting to set devices, which currently // leads to a silent synchronization/correctness issue. However, as this // might worsen perf in CPU-only cases, we should only do so after careful // benchmarks. if (impl_.type() != c10::kCPU) { actualStorages = storages.has_value() ? std::move(*storages) : extractStorages(value); usedDevices = getDevicesOfStorages(impl_, actualStorages); ensureIsSubsetOfDevices(usedDevices, devices_); } } catch (const std::exception&) { setError(std::current_exception()); return; } std::unique_lock lock(mutex_); TORCH_CHECK( !completed(), "Attempting to mark a completed Future as complete again. Note that " "a Future can only be marked completed once."); // Only set value_ and completed_ flag once all checks and preparation steps // have returned successfully to allow for proper error propagation. value_ = std::move(value); completed_ = true; currentDevice_ = impl_.getDevice(); storages_ = std::move(actualStorages); for (const c10::Device& device : usedDevices) { c10::Event event(impl_.type()); event.record(impl_.getStream(device)); events_.push_back(std::move(event)); } std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); finished_cv_.notify_all(); for (auto& callback : cbs) { invokeCallback(std::move(callback)); } } void markCompleted() { markCompleted(IValue{}); } void setError(std::exception_ptr eptr) { std::unique_lock lock(mutex_); setErrorInternal(std::move(eptr), lock); } void setErrorIfNeeded(std::exception_ptr eptr) { std::unique_lock lock(mutex_); if (completed_) { // This should be rare and shouldn't cause log spew. Its important to // log errors and thats why we have this log here. std::string msg = c10::str( "Skipping setting following error on the Future since " "it is already marked completed (this is not necessarily " "an error):\n", tryRetrieveErrorMessageInternal(eptr)); if (eptr_) { msg += c10::str( ", \nOriginal exception:\n", tryRetrieveErrorMessageInternal(eptr_)); } LOG(INFO) << msg; return; } else { setErrorInternal(std::move(eptr), lock); } } // Get the result of the current future. IValue value() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); if (eptr_) { std::rethrow_exception(eptr_); } return value_; } // This accessor should only be used if we know that the future is // completed() with no error. 
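// A hedged end-to-end sketch of the completion API described above (the
// element type and the callback body are illustrative only):
//
//   auto fut = c10::make_intrusive<ivalue::Future>(c10::TensorType::get());
//   fut->addCallback([](ivalue::Future& f) {
//     // runs once the future completes (immediately if it already has)
//   });
//   fut->markCompleted(IValue(at::ones({2, 2})));
//   fut->wait();                // or waitAndThrow() to rethrow a stored error
//   IValue result = fut->value();
//
// (Per the preceding note, the constValue()/storages() accessors below are
// only safe once the future is known to have completed without error.)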
const IValue& constValue() const { std::unique_lock lock(mutex_); AT_ASSERT(completed()); TORCH_INTERNAL_ASSERT( !eptr_, "value() accessor should only be used when future is not completed with ", "an error, but future had the following error: ", tryRetrieveErrorMessageInternal(eptr_) ); return value_; } // This accessor should only be used if we know that the future is // completed() with no error. const std::vector& storages() const { std::unique_lock lock(mutex_); AT_ASSERT(completed()); AT_ASSERT(!eptr_); return storages_; } /** * Add a callback to the future. * The callbacks will be executed once the future completes. * If the future has already completed, * this function will execute the callback immediately. */ template void addCallback(T callback) { #if __cpp_lib_is_invocable >= 201703 static_assert( std::is_invocable_r::value, "The callback must have signature void(Future&)"); #endif std::unique_lock lock(mutex_); if (completed()) { lock.unlock(); invokeCallback(std::move(callback)); return; } callbacks_.emplace_back(std::move(callback)); } /** * Add a callback to the future, and return another Future to hold the return * value of the callback. This is necessary when the callback provider needs * to know for sure when the callback has finished. */ template c10::intrusive_ptr then(T callback, TypePtr type) { using IValueWithStorages = std::tuple>; #if __cpp_lib_is_invocable >= 201703 static_assert( guts::disjunction< std::is_invocable_r, std::is_invocable_r>::value, "The callback must have signature IValue(Future&) or " "std::tuple>(Future&)"); #endif auto childFut = createInstance(std::move(type)); addCallback([childFut, cb = std::move(callback)](Future& parentFut) mutable { try { guts::if_constexpr, IValueWithStorages>::value>( [&](auto identity) { IValue value; std::vector storages; std::tie(value, storages) = identity(cb)(parentFut); childFut->markCompleted(std::move(value), std::move(storages)); }, [&](auto identity) { childFut->markCompleted(identity(cb)(parentFut)); }); } catch (std::exception&) { childFut->setError(std::current_exception()); } }); return childFut; } template c10::intrusive_ptr thenAsync(T callback, TypePtr type) { #if __cpp_lib_is_invocable >= 201703 static_assert( std::is_invocable_r, T, Future&>::value, "The callback must have signature c10::intrusive_ptr(Future&)"); #endif auto childFut = createInstance(std::move(type)); addCallback( [childFut, cb = std::move(callback)](Future& parentFut) mutable { c10::intrusive_ptr intermediateFut; try { intermediateFut = cb(parentFut); } catch (std::exception&) { childFut->setError(std::current_exception()); return; } intermediateFut->addCallback( [childFut = std::move(childFut)](Future& intermediateFut) { if (intermediateFut.hasError()) { childFut->setError(intermediateFut.exception_ptr()); } else { childFut->markCompleted( intermediateFut.value(), intermediateFut.storages()); } }); }); return childFut; } // Tries to retrieve the error message from std::exception_ptr. std::string tryRetrieveErrorMessage() const { TORCH_CHECK(hasError(), "No error present on the future."); std::unique_lock lock(mutex_); return tryRetrieveErrorMessageInternal(eptr_); } // Check if the current future has completed bool completed() const { return completed_; } bool hasValue() const { std::unique_lock lock(mutex_); return completed_ && !eptr_; } bool hasError() const { std::unique_lock lock(mutex_); return eptr_ ? 
true : false; } std::exception_ptr exception_ptr() const { std::unique_lock lock(mutex_); return eptr_; } TORCH_API friend std::ostream& operator<<( std::ostream& out, const Future& v); TypePtr elementType() const { return type_; } const std::vector& devices() const { return devices_; } // This method should be used when one intends to manually create a child // future, for example when implementing a customized version of then(). c10::intrusive_ptr createInstance(at::TypePtr type) { return c10::make_intrusive(std::move(type), devices_); } private: // This method should always be used when invoking a callback (regardless of // how/when that happens) as it will ensure that the proper "environment" is // set up before running the callback, as in, it will set up the CUDA streams, // synchronize them with the value, and so on (if needed). template void invokeCallback(T callback) { #if __cpp_lib_is_invocable >= 201703 static_assert( std::is_invocable_r::value, "The callback must have signature void(Future&)"); #endif c10::OptionalDeviceGuard deviceGuard(currentDevice_); std::vector streams; for (const c10::Device& device : devices_) { streams.push_back(impl_.getStreamFromGlobalPool(device)); } c10::MultiStreamGuard streamGuard(streams); synchronizeWithCurrentStreams(); callback(*this); } // This method should be called before this future's value is used, as it // ensures that the CUDA streams that are "current" at the callsite properly // synchronize with the value. void synchronizeWithCurrentStreams() { for (c10::Event& event : events_) { event.block(impl_.getStream(event.device())); } for (const WeakStorage& weak_storage : storages_) { c10::intrusive_ptr storage = weak_storage.lock(); if (!storage) { continue; } if (!storage->device().is_cpu()) { impl_.recordDataPtrOnStream( storage->data_ptr(), impl_.getStream(storage->device())); } } } void setErrorInternal( std::exception_ptr eptr, std::unique_lock& lock) { TORCH_CHECK( !eptr_, "Error already set on this Future: ", tryRetrieveErrorMessageInternal(eptr_), ", trying to set error: ", tryRetrieveErrorMessageInternal(eptr)); TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed"); completed_ = true; eptr_ = std::move(eptr); std::vector> cbs; cbs.swap(callbacks_); lock.unlock(); finished_cv_.notify_all(); for (auto& callback : cbs) { invokeCallback(std::move(callback)); } } // Tries to retrieve the error message from std::exception_ptr. std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { try { std::rethrow_exception(eptr); } catch (const std::exception& e) { return e.what(); } catch (...) { return "Unknown Exception Type"; } } // Defined in ivalue.cpp. 
static std::vector extractStorages( const at::IValue& value); static std::vector getDevicesOfStorages( const c10::impl::VirtualGuardImpl& impl, const std::vector& storages) { c10::DeviceIndex deviceCount = impl.deviceCount(); std::vector isDeviceUsed(deviceCount, false); for (const WeakStorage& weak_storage : storages) { c10::intrusive_ptr storage = weak_storage.lock(); if (!storage) { continue; } c10::Device device = storage->device(); if (!device.is_cpu()) { TORCH_CHECK_VALUE( device.type() == impl.type(), "Expected all data ptrs to be on a device of type ", impl.type(), ", got one on device ", device); isDeviceUsed[device.index()] = true; } } std::vector devices; for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) { if (isDeviceUsed[idx]) { devices.emplace_back(impl.type(), idx); } } return devices; } static std::string formatSetOfDevices( const std::vector& devices) { if (devices.empty()) { return "(none)"; } std::ostringstream oss; oss << devices[0]; for (const auto idx : c10::irange(1, devices.size())) { if (idx == devices.size() - 1) { oss << " and "; } else { oss << ", "; } oss << devices[idx]; } return oss.str(); } static c10::DeviceType getTypeOfDevices( const std::vector& devices) { if (devices.empty()) { return c10::kCPU; } c10::DeviceType deviceType = devices[0].type(); for (const auto idx : c10::irange(1, devices.size())) { TORCH_CHECK_VALUE( devices[idx].type() == deviceType, "Expected all devices to be of the same type, but got a mismatch between ", devices[0], " and ", devices[idx]); } return deviceType; } // We need devices to be sorted in order to use ensureIsSubsetOfDevices. static std::vector sortAndDeduplicateDevices( const c10::impl::VirtualGuardImpl& /*impl*/, std::vector devices) { std::sort( devices.begin(), devices.end(), [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); // Deduplicate by compacting. size_t targetIdx = 0; for (const auto sourceIdx : c10::irange(devices.size())) { TORCH_CHECK_VALUE( devices[sourceIdx].has_index(), "Expected devices to have indices, got ", devices[sourceIdx]); if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) { // It's a duplicate, skip it. continue; } if (sourceIdx != targetIdx) { devices[targetIdx] = devices[sourceIdx]; } targetIdx++; } // If there were duplicates there's now a gap at the end: trim it. Resizing // requires the item type to be default-constructible (which c10::Device is // not) because in principle it could be required to create new items. Since // we know we'll shrink the vector, we provide a custom dummy value instead. devices.resize(targetIdx, c10::Device(c10::kCPU)); return devices; } static void ensureIsSubsetOfDevices( const std::vector& subset, const std::vector& superset) { // We assume the devices in both vectors have the same consistent type, and // their indices are unique and sorted. 
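// For instance, with subset = {cuda:0, cuda:3} (devices actually used by the
// value's storages) and superset = {cuda:0, cuda:1, cuda:2} (devices this
// future was constructed with), the set_difference below yields
// excessDevices = {cuda:3} and the TORCH_CHECK_VALUE fails, naming cuda:3 in
// its message. (Illustrative device sets; both are assumed sorted and
// deduplicated, per the comment above.)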
std::vector excessDevices; std::set_difference( subset.begin(), subset.end(), superset.begin(), superset.end(), std::back_inserter(excessDevices), [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); TORCH_CHECK_VALUE( excessDevices.empty(), "The result contained tensors residing on device(s) ", formatSetOfDevices(excessDevices), " which are not among the expected device(s) ", formatSetOfDevices(superset)); } mutable std::mutex mutex_; std::atomic_bool completed_ = {false}; // is this future complete std::condition_variable finished_cv_; IValue value_; // when finished the value TypePtr type_; std::vector> callbacks_; std::exception_ptr eptr_; // An upcast pointer to a virtual class which allows us to manipulate events, // streams, ... in a generic way, without an explicit dependency on CUDA. const c10::impl::VirtualGuardImpl impl_; // The device that was current when markCompleted was called, which we'll // restore when invoking callbacks. It's optional because we'll only store it // if the future completes successfully. optional currentDevice_; // The events that correspond to the completion of the async I/O kernels. They // are recorded on the appropriate streams when the future is marked completed // and can then be queried/waited/blocked on. There is one event for each // distinct device on which the value's tensors reside. std::vector events_; // A cached version of the storages extracted from the value when the future // is first marked completed. std::vector storages_; // The bounding set of devices that this future, and any of its children, is // allowed to use. This is a superset of the set of devices used by the events // above. We need this to know what streams (for which devices) to set as // current when invoking a callback, thus allowing the callback to use devices // that the parent future didn't use. This field is set to the value provided // in the constructor and will be "inherited" by all child futures. const std::vector devices_; }; // Input is a list of Futures with the same target type. // Output is a Future to the List of completed Futures. TORCH_API intrusive_ptr collectAll( c10::List> srcs); // Input is a List of Futures with the same target type. // Output is a Future that will be updated with a seen value. TORCH_API intrusive_ptr collectAny( c10::List> srcs); // User-defined object. struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { public: // In general, class types hold a shared_ptr to its owning CompilationUnit, // so that its type and methods do not get deallocated while the class exists. // However, the CompilationUnit holds ownership of the type's graphs, so // inserting a constant object into a Graph would create a reference cycle if // that constant object held a shared_ptr to its CU. For these objects we // instatiate them with non-owning references to its CU Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { slots_.resize(numSlots); } Object(StrongTypePtr type, size_t numSlots) : type_(WeakOrStrongTypePtr(std::move(type))) { slots_.resize(numSlots); } static c10::intrusive_ptr create( WeakOrStrongTypePtr type, size_t numSlots) { return c10::make_intrusive(std::move(type), numSlots); } static c10::intrusive_ptr create( StrongTypePtr type, size_t numSlots) { return c10::make_intrusive(std::move(type), numSlots); } static c10::intrusive_ptr create(ClassTypePtr classType, size_t numSlots); /** * Slot API. 
* * Attributes are stored as a simple vector so that lookups are fast at * runtime. A "slot" is just an index into that vector, which can be computed * statically if you have access to the class type. Use this API if you are * writing compiler stuff. */ void setSlot(size_t slot, IValue v) { if (slot >= slots_.size()) { // for module types, it is possible that the members of the class have // expanded after the object was created. In this case, we expand // the slots to the right size resizeObject(slot); } slots_[slot] = std::move(v); } const IValue& getSlot(size_t slot) const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size()); // NOTE: This lookup is fairly hot, so we use unchecked access to the // vector. Errors should still be detectable with ASan. return slots_[slot]; } void unsafeRemoveSlot(size_t slot) { TORCH_CHECK(slot < slots_.size()); slots_.erase(slots_.begin() + slot); } /** * Attribute API. * * Wrappers around the slot stuff so that users can access attributes * directly. Use this API if you are a user. * * Note: Unlike in Python, TorchScript must make a distinction between * attributes (which are IValues) and methods (which are Methods). If you * want a method, use `obj.type()->getMethod()` */ IValue getAttr(const std::string& name) const; void setAttr(const std::string& name, IValue v); // Remove attribute by name, caller is responsible for // the safety of this operation // We didn't remove the attribute in the type because the type // might be shared by multiple objects. // Therefore after removing attribute, the object is in an inconsistent // state where it has more attribute types in its Type than // the attribute slots it has, user needs to make sure the object // has consistent by removing the attribute in type as well void unsafeRemoveAttr(const std::string& name); std::string name() const; const std::vector& slots() const { return slots_; } std::shared_ptr type() const; std::shared_ptr compilation_unit() { if (type_.holds_strong_ref()) { return type_.cu_.getStrongRefOrThrow(); } else { auto weak_ptr = type_.cu_.getWeakRefOrThrow(); return std::shared_ptr(weak_ptr); } } c10::intrusive_ptr copy_to_weak_compilation_ref() const; void unsafe_make_weak_compilation_ref() { type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr()); } c10::intrusive_ptr copy() const; c10::intrusive_ptr deepcopy() const; c10::intrusive_ptr deepcopy(IValue::HashAliasedIValueMap& memo) const; bool is_weak_compilation_ref() const { return !type_.holds_strong_ref(); } bool is_empty_strong_compilation_ref() const { return type_.holds_empty_strong_ref(); } private: void resizeObject(size_t slot); WeakOrStrongTypePtr type_; std::vector slots_; }; // virtual ivalue PyObjectHolder that hold a py::object, we make this virtual // because the py::object and refcounting logic should happen in libtorch_python // see concrete implementation in python_ivalue.h struct ivalue::PyObjectHolder : c10::intrusive_ptr_target { public: virtual PyObject* getPyObject() = 0; virtual c10::InferredType tryToInferType() = 0; virtual IValue toIValue(const TypePtr& type, c10::optional N = c10::nullopt) = 0; virtual std::string toStr() = 0; virtual std::vector extractTensors() = 0; virtual ~PyObjectHolder(){}; }; struct ivalue::EnumHolder : c10::intrusive_ptr_target { public: EnumHolder(std::shared_ptr type, std::string name, IValue value) : type_(std::move(type)), name_(std::move(name)), value_(std::move(value)) {} bool is(const ivalue::EnumHolder& rhs) { return *this == rhs; } friend bool operator==( const ivalue::EnumHolder& 
lhs, const ivalue::EnumHolder& rhs); TORCH_API friend std::ostream& operator<<( std::ostream& out, const EnumHolder& v); TORCH_API const std::string qualifiedClassName() const; const std::string unqualifiedClassName() const; const std::string& name() const { return name_; } const IValue& value() const { return value_; } std::shared_ptr type() const { return type_; } private: std::shared_ptr type_; std::string name_; IValue value_; }; #undef TORCH_FORALL_TAGS namespace detail { struct _guarded_unsigned_long_unique_dummy final { _guarded_unsigned_long_unique_dummy(int64_t){}; }; using _guarded_unsigned_long = std::conditional_t< std::is_same::value || std::is_same::value, _guarded_unsigned_long_unique_dummy, unsigned long>; } // namespace detail inline ivalue::Object& IValue::toObjectRef() const { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); return *static_cast(payload.u.as_intrusive_ptr); } // note: when adding a DEFINE_TO case here you should also add a // toX method to IValue. These named methods are much more discoverable // than the to templated function. #define DEFINE_TO(T, method_name) \ template <> \ inline T IValue::to()&& { \ return static_cast(std::move(*this).method_name()); \ } \ template <> \ inline c10::detail::ivalue_to_const_ref_overload_return::type IValue::to() const& { \ typedef c10::detail::ivalue_to_const_ref_overload_return::type return_type; \ return static_cast(this->method_name()); \ } DEFINE_TO(at::Tensor, toTensor) DEFINE_TO(at::Storage, toStorage) DEFINE_TO(c10::Stream, toStream) DEFINE_TO(float, toDouble) DEFINE_TO(double, toDouble) DEFINE_TO(c10::complex, toComplexDouble) DEFINE_TO(unsigned char, toInt) DEFINE_TO(signed char, toInt) DEFINE_TO(unsigned short, toInt) DEFINE_TO(short, toInt) DEFINE_TO(int, toInt) DEFINE_TO(uint32_t, toInt) DEFINE_TO(uint64_t, toInt) DEFINE_TO(detail::_guarded_unsigned_long, toInt) DEFINE_TO(int64_t, toInt) DEFINE_TO(bool, toBool) DEFINE_TO(c10::intrusive_ptr, toBlob); DEFINE_TO(c10::intrusive_ptr, toString) DEFINE_TO(c10::intrusive_ptr, toObject) DEFINE_TO(at::Scalar, toScalar) DEFINE_TO(c10::List, toIntList) DEFINE_TO(c10::List, toDoubleList) DEFINE_TO(c10::List>, toComplexDoubleList) DEFINE_TO(c10::List, toBoolList) DEFINE_TO(c10::List, toTensorList) DEFINE_TO(c10::impl::GenericList, toList) DEFINE_TO(c10::impl::GenericDict, toGenericDict) DEFINE_TO(c10::intrusive_ptr, toTuple) DEFINE_TO(std::string, toStringRef) DEFINE_TO(c10::string_view, toStringView) DEFINE_TO(c10::intrusive_ptr, toFuture) DEFINE_TO(c10::intrusive_ptr, toRRef) DEFINE_TO(c10::intrusive_ptr, toQuantizer) DEFINE_TO(IValue, toIValue) DEFINE_TO(c10::Device, toDevice) DEFINE_TO(at::ScalarType, toScalarType) DEFINE_TO(at::Layout, toLayout) DEFINE_TO(at::MemoryFormat, toMemoryFormat) DEFINE_TO(at::QScheme, toQScheme) DEFINE_TO(at::Dimname, toDimname) DEFINE_TO(at::Generator, toGenerator) DEFINE_TO(c10::SymInt, toSymInt) DEFINE_TO(c10::SymFloat, toSymFloat) template struct _fake_type {}; // generic_to converts an IValue from a generic list or generic dict // to a concrete list/dict type likelike List, Dict<...> or optional. // Note that in the case of lists, this only works for IValue-based lists, // i.e. not for int64_t, double, ... // generic_to is an implementation detail of IValue::to and not // supposed to be called directly. // The _fake_type parameter allows us to overload // based on the return type. 
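// A hedged usage sketch of the to<T>() machinery (DEFINE_TO above, generic_to
// below); variable names are illustrative:
//
//   IValue iv(c10::List<int64_t>({1, 2, 3}));
//   auto shared = iv.to<c10::List<int64_t>>();           // DEFINE_TO path: shares the underlying list
//   auto copied = iv.to<std::vector<int64_t>>();         // generic_to path: deep-copies the elements
//   auto maybe  = IValue().to<c10::optional<int64_t>>(); // None converts to c10::nullopt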
template // TODO this is deprecated but we don't throw a warning because a lot of ops in // native_functions.yaml still return std::vector. // C10_DEPRECATED_MESSAGE("IValues based on std::vector are potentially slow // and deprecated. Please use torch::List instead.") std::vector generic_to(IValue ivalue, _fake_type>) { // We need to do a deep copy of the vector because there might be other // references to this same IValue that also use the list. We can't just // move the elements out. auto list = std::move(ivalue).to>(); std::vector result; result.reserve(list.size()); for (Elem v : list) { result.push_back(std::move(v)); } return result; } template c10::intrusive_ptr IValue::toCustomClass() && { static_assert( std::is_base_of::value == true, "toCustomClass requires that template parameter T must inherit " "from torch::CustomClassHolder"); auto obj = toObject(); TORCH_CHECK( obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); const auto* expected_type = c10::getCustomClassType>().get(); ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; } template c10::intrusive_ptr IValue::toCustomClass() const& { static_assert( std::is_base_of::value == true, "toCustomClass requires that template parameter T must inherit " "from torch::CustomClassHolder"); auto obj = toObject(); TORCH_CHECK( obj->slots().size() == 1, "Tried to cast IValue to custom class but it did " "not contain a custom class!"); const auto* expected_type = c10::getCustomClassType>().get(); ivalue::checkCustomClassType(expected_type, type().get()); auto userObj = c10::static_intrusive_pointer_cast(obj->getSlot(0).toCapsule()); return userObj; } template T generic_to(IValue ivalue, _fake_type) { using ElemType = typename std::remove_pointer::type::element_type; return std::move(ivalue).toCustomClass(); } template tagged_capsule generic_to(IValue ivalue, _fake_type>) { return tagged_capsule{std::move(ivalue)}; } template c10::List generic_to(IValue ivalue, _fake_type>) { return impl::toTypedList(std::move(ivalue).toList()); } template static T createVectorLikeFromList(const c10::detail::ListImpl* impl) { T result; result.reserve(impl->list.size()); for (size_t i = 0, N = impl->list.size(); i < N; ++i) { result.push_back(impl->list[i].to()); } return result; } template static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { return createVectorLikeFromList>(impl); } template std::vector createVectorFromList(const c10::List& impl) { std::vector result; result.reserve(impl.size()); for (size_t i = 0, N = impl.size(); i < N; ++i) { result.push_back(impl[i]); } return result; } template OptionalArray generic_to(IValue ivalue, _fake_type>) { if (ivalue.isNone()) { return {}; } return createVectorFromList( std::move(ivalue).to>() ); } namespace detail { template std::array generic_to_array( IValue ivalue, _fake_type>, std::index_sequence) { // We need to do a deep copy of the array because there might be other // references to this same IValue that also use the list. We can't just // move the elements out. 
auto list = std::move(ivalue).to>(); TORCH_CHECK( list.size() == sizeof...(I), "Tried to convert a List with ", list.size(), " elements to a fixed-size array of size ", sizeof...(I)); return {list[I]...}; } } // namespace detail template std::array generic_to( IValue ivalue, _fake_type> ft) { return detail::generic_to_array(ivalue, ft, std::make_index_sequence()); } template c10::Dict generic_to( IValue ivalue, _fake_type>) { return impl::toTypedDict(std::move(ivalue).toGenericDict()); } template C10_DEPRECATED_MESSAGE( "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict instead.") std::unordered_map generic_to( IValue ivalue, _fake_type>) { std::unordered_map specialized_dict; for (const auto& item : std::move(ivalue).toGenericDict()) { specialized_dict[item.key().template to()] = item.value().template to(); } return specialized_dict; } template c10::optional generic_to(IValue ivalue, _fake_type>) { if (ivalue.isNone()) { return c10::nullopt; } return std::move(ivalue).to(); } namespace detail { template Tuple generic_to_tuple_impl( const ivalue::TupleElements& t, std::index_sequence) { return std::make_tuple( t[INDEX].to::type>()...); } } // namespace detail template < typename... Args, typename Indices = std::make_index_sequence, std::enable_if_t< !guts::disjunction< std::is_lvalue_reference..., guts::negation>...>::value, std::nullptr_t> = nullptr> std::tuple generic_to(IValue ivalue, _fake_type>) { const auto& vals = ivalue.toTupleRef().elements(); TORCH_CHECK(vals.size() == sizeof...(Args)); return detail::generic_to_tuple_impl>(vals, Indices{}); } template inline T IValue::to() && { return generic_to(std::move(*this), _fake_type{}); } template <> inline c10::optional IValue::to() && { // In the default implementation, the IValue is destroyed with std::move. // But if the unboxed type is optional we cannot destroy // the IValue. 
  return generic_to(*this, _fake_type<c10::optional<at::Tensor>>{});
}

template <typename T>
inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
  return generic_to(*this, _fake_type<T>{});
}

inline c10::List<int64_t> IValue::toIntList() && {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<int64_t> IValue::toIntList() const& {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<int64_t> IValue::toIntVector() const {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toIntVector on null intrusive_ptr IValue");
  return createVectorFromList<int64_t>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline at::DimVector IValue::toDimVector() const {
  AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toDimVector on null intrusive_ptr IValue");
  return createVectorLikeFromList<at::DimVector>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<double> IValue::toDoubleList() && {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<double> IValue::toDoubleList() const& {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<double> IValue::toDoubleVector() const {
  AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toDoubleVector on null intrusive_ptr IValue");
  return createVectorFromList<double>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const {
  AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toComplexDoubleVector on null intrusive_ptr IValue");
  return createVectorFromList<c10::complex<double>>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
inline c10::List<bool> IValue::toBoolList() && {
  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
  return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<bool> IValue::toBoolList() const& {
  AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
  return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<at::Tensor> IValue::toTensorList() && {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>());
}
inline c10::List<at::Tensor> IValue::toTensorList() const& {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>());
}
inline std::vector<at::Tensor> IValue::toTensorVector() const {
  AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toTensorVector on null intrusive_ptr IValue");
  return createVectorFromList<at::Tensor>(
      static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
}
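// A hedged sketch of how the rvalue/lvalue accessor pairs above differ (names
// are illustrative): a && overload on an expiring IValue steals the reference
// and leaves the IValue as None, while the const& overload bumps the refcount
// and leaves the IValue intact:
//
//   IValue iv(std::vector<int64_t>{2, 3});
//   c10::List<int64_t> a = iv.toIntList();             // const&: iv still holds the list
//   std::vector<int64_t> v = iv.toIntVector();         // independent copy of the elements
//   c10::List<int64_t> b = std::move(iv).toIntList();  // &&: moves the reference out; iv becomes None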
inline c10::List> IValue::toOptionalTensorList() && { AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); return c10::List>(moveToIntrusivePtr()); } inline c10::List> IValue::toOptionalTensorList() const& { AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); return c10::List>(toIntrusivePtr()); } inline std::vector> IValue::toOptionalTensorVector() const { AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "called toOptionalTensorVector on null intrusive_ptr IValue"); return createVectorFromList>( static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toList() && { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); } inline c10::List IValue::toList() const& { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); return c10::List(toIntrusivePtr()); } inline c10::ArrayRef IValue::toListRef() const { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "called toListRef on null intrusive_ptr IValue"); return static_cast(payload.u.as_intrusive_ptr) ->list; } inline c10::Dict IValue::toGenericDict() && { AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); return c10::Dict(moveToIntrusivePtr()); } inline c10::Dict IValue::toGenericDict() const& { AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind()); return c10::Dict(toIntrusivePtr()); } inline c10::intrusive_ptr IValue::toTuple() && { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); return moveToIntrusivePtr(); } inline c10::intrusive_ptr IValue::toTuple() const& { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); return toIntrusivePtr(); } inline ivalue::Tuple& IValue::toTupleRef() const { AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "called toTupleRef on null intrusive_ptr IValue"); return *static_cast( payload.u.as_intrusive_ptr); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Tuple) { payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } template < typename... Args, std::enable_if_t< !guts::disjunction< std::is_lvalue_reference..., guts::negation>...>::value, std::nullptr_t>> inline IValue::IValue(const std::tuple& t) : IValue( std::move(c10::guts::apply(c10::ivalue::Tuple::create, t))) { } template < typename... 
Args, std::enable_if_t< !guts::disjunction< std::is_lvalue_reference..., guts::negation>...>::value, std::nullptr_t>> inline IValue::IValue(std::tuple&& t) : IValue( std::move(c10::guts::apply(c10::ivalue::Tuple::create, std::move(t)))) { } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::String) { payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(std::string v) : IValue(ivalue::ConstantString::create(std::move(v))) {} inline IValue::IValue(c10::impl::GenericList v) : tag(Tag::GenericList) { payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template > inline IValue::IValue(c10::List&& v) : IValue(impl::toList(std::move(v))) {} template > inline IValue::IValue(const c10::List& v) : IValue(impl::toList(v)) {} template > inline IValue::IValue(at::ArrayRef v) : IValue(c10::List()) { auto list = to>(); list.reserve(v.size()); for (const auto& e : v) { list.push_back(e); } } template > inline IValue::IValue(at::ArrayRef v) : IValue() { auto vi = c10::asIntArrayRefSlowOpt(v); if (vi.has_value()) { // This list is entirely integers; ensure it is typed as // an IntList so toIntList works *this = IValue(*vi); } else { // This list has SymInts; type it as a SymInt *this = IValue(impl::toList(c10::List())); auto list = to>(); list.reserve(v.size()); for (const auto& e : v) { list.push_back(e); } } } template > inline IValue::IValue(at::OptionalArrayRef mb_v) : IValue() { if (!mb_v.has_value()) return; *this = IValue(*mb_v); } template > inline IValue::IValue(const std::vector& v) : IValue() { *this = IValue(at::ArrayRef(v)); } template > inline IValue::IValue(const std::vector& v) : IValue(c10::List()) { auto list = to>(); list.reserve(v.size()); for (const auto& e : v) { list.push_back(e); } } template > inline IValue::IValue(c10::OptionalArrayRef v) : IValue() { if (v.has_value()) { *this = IValue(std::move(*v)); } } template inline IValue::IValue(std::array v) : IValue(c10::List()) { auto list = to>(); list.reserve(v.size()); for (auto& e : v) { list.push_back(std::move(e)); } } template > inline IValue::IValue(c10::IListRef v) : IValue() { constexpr bool boxed_type_constructs_ivalue = std::is_constructible::boxed_type>::value; // First, we try to use the boxed value. // If we fail (either it's not in the boxed state, or its boxed type // can not construct an IValue), we fallback to copying the list. 
template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>>
inline IValue::IValue(c10::IListRef<T> v) : IValue() {
  constexpr bool boxed_type_constructs_ivalue =
      std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value;
  // First, we try to use the boxed value.
  // If we fail (either it's not in the boxed state, or its boxed type
  // can not construct an IValue), we fall back to copying the list.
  if (boxed_type_constructs_ivalue && v.isBoxed()) {
    *this = IValue(impl::toList(v.toBoxed()));
  } else {
    c10::List<T> list;
    list.reserve(v.size());
    for (const auto& t : v) {
      list.push_back(t);
    }
    *this = IValue(impl::toList(std::move(list)));
  }
}

inline IValue::IValue(c10::impl::GenericDict v) : tag(Tag::GenericDict) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
}
template <class Key, class Value>
inline IValue::IValue(c10::Dict<Key, Value> v)
    : IValue(impl::toGenericDict(std::move(v))) {}

template <class Key, class Value>
inline IValue::IValue(std::unordered_map<Key, Value> v)
    : IValue(Dict<Key, Value>()) {
  auto dict = to<c10::Dict<Key, Value>>();
  dict.reserve(v.size());
  for (auto& e : v) {
    dict.insert(std::move(e.first), std::move(e.second));
  }
}

template <class T, IValue::enable_if_ivalue_constructible<T>>
inline IValue::IValue(c10::optional<T> v) : IValue() {
  if (v.has_value()) {
    *this = IValue(std::move(*v));
  }
}

inline IValue::IValue(c10::nullopt_t) : IValue() {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v) : tag(Tag::Object) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
    : tag(Tag::PyObject) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
    : tag(Tag::Enum) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

inline IValue IValue::make_capsule(
    intrusive_ptr<torch::CustomClassHolder> blob) {
  IValue iv;
  iv.tag = Tag::Capsule;
  iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
  return iv;
}

template <
    typename T,
    std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
IValue::IValue(c10::intrusive_ptr<T> custom_class) {
  auto classType = []() {
    try {
      return c10::getCustomClassType<c10::intrusive_ptr<T>>();
    } catch (const c10::Error&) {
      throw c10::Error(
          "Trying to instantiate a class that isn't a registered custom class: " +
              std::string(c10::util::get_fully_qualified_type_name<T>()),
          "");
    }
  }();
  auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
  ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
  tag = Tag::Object;
}

inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v) : tag(Tag::Future) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v) : tag(Tag::RRef) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v) : tag(Tag::Quantizer) {
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

template <typename T>
inline IValue::IValue(c10::complex<T> c) : tag(Tag::ComplexDouble) {
  auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
  payload.u.as_intrusive_ptr = v.release();
}

inline const std::string& IValue::toStringRef() const {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toStringRef on null intrusive_ptr IValue");
  return static_cast<const c10::ivalue::ConstantString*>(
             payload.u.as_intrusive_ptr)
      ->string();
}
inline c10::optional<std::reference_wrapper<const std::string>> IValue::
    toOptionalStringRef() const {
  if (isNone()) {
    return c10::nullopt;
  }
  AT_ASSERT(isString(), "Expected optional<string> but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toOptionalStringRef on null intrusive_ptr IValue");
  return std::reference_wrapper<const std::string>(
      static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr)
          ->string());
}
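// A hedged usage sketch (illustrative only; values are made up): boxing a
// typed dict and reading it back generically, plus a disengaged optional.
//
//   c10::Dict<std::string, int64_t> d;
//   d.insert("answer", 42);
//   c10::IValue iv(d);                                           // Tag::GenericDict
//   auto generic = iv.toGenericDict();                           // c10::Dict<IValue, IValue>
//   int64_t answer = generic.at(c10::IValue(std::string("answer"))).toInt(); // 42
//
//   c10::IValue none(c10::optional<int64_t>{});                  // disengaged -> isNone()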
inline c10::string_view IValue::toStringView() const {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
      "called toStringView on null intrusive_ptr IValue");
  return static_cast<const c10::ivalue::ConstantString*>(
             payload.u.as_intrusive_ptr)
      ->string_view();
}

inline PyObject* IValue::toPyObject() const {
  return toPyObjectHolder()->getPyObject();
}

template <typename T>
inline optional<T> IValue::toOptional() {
  if (this->isNone()) {
    return nullopt;
  }
  return this->to<T>();
}

template <typename T>
inline optional<T> IValue::toOptional() const {
  if (this->isNone()) {
    return nullopt;
  }
  return this->to<T>();
}

inline bool IValue::isCustomClass() const {
  return torch::isCustomClass(*this);
}

inline bool IValue::isSameIdentity(const IValue& rhs) const {
  // We choose not to use memcmp for the payload check because the union type
  // may contain random padding bytes.

  // Semantics:
  // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
  //    Str) compare by value equality.
  // 2. If it is a tensor type, we need to take undefined tensors into account.
  // 3. An undefined tensor and None have the same identity, and vice versa.
  // 4. If it is a reference type (i.e. isIntrusivePtr()), the result is true
  //    only when the pointed-to object is the same.
  // 5. False for all other comparisons.
  if (this->isNone() && rhs.isNone()) {
    return true;
  } else if (this->isBool() && rhs.isBool()) {
    // for bool type, do equality check
    return this->toBool() == rhs.toBool();
  } else if (this->isTensor() && rhs.isTensor()) {
    return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
  } else if (this->isTensor() && rhs.isNone()) {
    // special case: undefined tensor and None are the same identity
    return !this->payload.as_tensor.defined();
  } else if (this->isNone() && rhs.isTensor()) {
    // special case: undefined tensor and None are the same identity
    return !rhs.payload.as_tensor.defined();
  } else if (this->isInt() && rhs.isInt()) {
    return this->toInt() == rhs.toInt();
  } else if (this->isDouble() && rhs.isDouble()) {
    return this->toDouble() == rhs.toDouble();
  } else if (this->isString() && rhs.isString()) {
    return this->toStringRef() == rhs.toStringRef();
  } else {
    // for objects held in an IValue, do a shallow compare on the pointer
    // address to test identity
    return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
        this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
  }
}

namespace ivalue {
namespace detail {

template <typename T>
IValue from_(T&& x, std::true_type) {
  return IValue(std::forward<T>(x));
}
template <typename T>
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
  return IValue(std::move(x));
}
template <typename T>
IValue from_(T&& /*x*/, std::false_type) {
  static_assert(
      guts::false_t<T>::value,
      "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
  return IValue();
}

} // namespace detail

template <typename T>
IValue from(T&& x) {
  return detail::from_(
      std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
}

} // namespace ivalue
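// A hedged usage sketch (illustrative only) of the identity semantics above:
//
//   at::Tensor undef;                          // default-constructed, undefined
//   c10::IValue a(undef), b;                   // b is None
//   bool same = a.isSameIdentity(b);           // true: undefined tensor ~ None
//
//   c10::IValue maybe(int64_t(7));
//   c10::optional<int64_t> o = maybe.toOptional<int64_t>();  // contains 7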
template <>
struct MaybeOwnedTraits<IValue> {
  using owned_type = IValue;
  using borrow_type = IValue;

  static borrow_type createBorrow(const owned_type& from) {
    if (!from.isPtrType()) {
      return from;
    }
    if (from.isTensor()) {
      return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
    } else {
      return IValue(from.payload, from.tag);
    }
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.clearToNone();
    if (!rhs.isPtrType()) {
      lhs = rhs;
    } else if (rhs.isTensor()) {
      lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
    } else {
      lhs = IValue(rhs.payload, rhs.tag);
    }
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.clearToNone();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type&) {
    return true;
  }
};

template <>
struct IValue::TagType<c10::Type> {
  static TORCH_API c10::TypePtr get(const IValue&);
};

template <>
struct IValue::TagType<c10::DynamicType> {
  static TORCH_API c10::TypePtr get(const IValue&);
};

template <typename T>
TypePtr IValue::type() const {
  return IValue::TagType<T>::get(*this);
}

} // namespace c10
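// A hedged usage sketch (illustrative only): borrowing an IValue through
// c10::MaybeOwned, which dispatches to the MaybeOwnedTraits<IValue>
// specialization above. The borrow copies the payload without taking
// ownership, and destroyBorrow() clears it without decrementing a refcount.
//
//   c10::IValue owned(std::string("payload"));
//   c10::MaybeOwned<c10::IValue> borrowed =
//       c10::MaybeOwned<c10::IValue>::borrowed(owned);
//   const std::string& s = borrowed->toStringRef();  // views the owned string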