#pragma once #include #include #include #include #include #include #include template inline std::vector makeStack(Inputs&&... inputs) { return {std::forward(inputs)...}; } inline at::Tensor dummyTensor(c10::DispatchKeySet ks, bool requires_grad=false) { auto* allocator = c10::GetCPUAllocator(); int64_t nelements = 1; auto dtype = caffe2::TypeMeta::Make(); int64_t size_bytes = nelements * dtype.itemsize(); auto storage_impl = c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size_bytes, allocator->allocate(size_bytes), allocator, /*resizable=*/true); at::Tensor t = at::detail::make_tensor(storage_impl, ks, dtype); // TODO: We add this to simulate the ideal case where we only have Autograd backend keys // on Tensor when it requires grad. But currently Autograd keys are added in TensorImpl // constructor by default. if (!requires_grad) { t.unsafeGetTensorImpl()->remove_autograd_key(); } return t; } inline at::Tensor dummyTensor(c10::DispatchKey dispatch_key, bool requires_grad=false) { return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad); } template inline std::vector callOp(const c10::OperatorHandle& op, Args... args) { auto stack = makeStack(std::forward(args)...); op.callBoxed(&stack); return stack; } template inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) { return op.typed().call(std::forward(args)...); } template inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::DispatchKey dispatchKey, Args... args) { return op.typed().callWithDispatchKey(dispatchKey, std::forward(args)...); } template inline Result callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle& op, c10::DispatchKeySet ks, Args... args) { return op.typed().redispatch(ks, std::forward(args)...); } inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) { auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); EXPECT_ANY_THROW( callOp(*op, dummyTensor(dispatch_key), 5); ); } inline void expectDoesntFindOperator(const char* op_name) { auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); EXPECT_FALSE(op.has_value()); } template inline void expectThrows(Functor&& functor, const char* expectMessageContains) { try { std::forward(functor)(); } catch (const Exception& e) { EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains)); return; } ADD_FAILURE() << "Expected to throw exception containing \"" << expectMessageContains << "\" but didn't throw"; } template void expectListEquals(c10::ArrayRef expected, std::array actual) { EXPECT_EQ(expected.size(), actual.size()); for (const auto i : c10::irange(expected.size())) { EXPECT_EQ(expected[i], actual[i]); } } template void expectListEquals(c10::ArrayRef expected, c10::ArrayRef actual) { EXPECT_EQ(expected.size(), actual.size()); for (const auto i : c10::irange(expected.size())) { EXPECT_EQ(expected[i], actual[i]); } } template void expectListEquals(c10::ArrayRef expected, c10::List actual) { EXPECT_EQ(expected.size(), actual.size()); for (const auto i : c10::irange(expected.size())) { EXPECT_EQ(expected[i], actual.get(i)); } } template void expectListEquals(c10::ArrayRef expected, std::vector actual) { EXPECT_EQ(expected.size(), actual.size()); for (const auto i : c10::irange(expected.size())) { EXPECT_EQ(expected[i], actual[i]); } } // NB: This is not really sound, but all of the type sets constructed here // are singletons so it's fine static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) { return legacyExtractDispatchKey(t.key_set()); }