#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>

#include <array>
#include <functional>
#include <memory>
#include <string>

namespace c10 {
class TORCH_API OperatorHandle;
}

namespace at {

// Kind of record function scope;
enum class C10_API_ENUM RecordScope : uint8_t {
  // c10/ATen ops, autograd nodes
  FUNCTION = 0,
  // Functions/nodes called from the autograd
  BACKWARD_FUNCTION,
  // TorchScript functions, methods
  TORCHSCRIPT_FUNCTION,
  // Kernel Function dtype Tag
  KERNEL_FUNCTION_DTYPE,
  // Torchbind custom class
  CUSTOM_CLASS,
  // Generic Build Feature
  BUILD_FEATURE,
  // Events emitted by the lite interpreter runtime
  LITE_INTERPRETER,
  // User defined scope (e.g. with record_function())
  USER_SCOPE,
  // Scopes for static runtime, a specialized TorchScript interpreter
  STATIC_RUNTIME_OP,
  STATIC_RUNTIME_MODEL,
  NUM_SCOPES, // must be the last in the list
};

} // namespace at

namespace std {
template <>
struct hash<at::RecordScope> {
  size_t operator()(const at::RecordScope& sc) const {
    return static_cast<std::size_t>(sc);
  }
};
} // namespace std

namespace at {

struct TORCH_API StringView {
  StringView() : StringView(nullptr) {}
  explicit StringView(const char* str_ptr)
      : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
  explicit StringView(std::string str)
      : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
        str_ptr_(owned_str_ptr_->c_str()) {}

  const char* str() const {
    return str_ptr_;
  }

  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    os << dt.str();
    return os;
  }

  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    return strcmp(lhs.str(), rhs.str()) == 0;
  }

  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  std::shared_ptr<std::string> owned_str_ptr_;
  const char* str_ptr_;
};

// Soft limit on the number of callbacks to use;
constexpr std::size_t kSoftLimitCallbacks = 4;

// An abstract base class for various observer contexts that can be attached to
// the RecordFunction.
struct ObserverContext {
  virtual ~ObserverContext() {}

 protected:
  ObserverContext() {}
};

typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
typedef c10::SmallVector<std::unique_ptr<ObserverContext>, kSoftLimitCallbacks>
    ObserverContextList;
typedef uint64_t RecordFunctionHandle;
struct RecordFunction;

//
// PyTorch callbacks/observers API:
//

/**
 * RecordFunctionCallback represents a pair of callbacks to be used with
 * RecordFunction, members:
 *   start, end - the callbacks to run when entering and exiting the scope;
 *     optionally, the start callback may return an ObserverContext which will
 *     be passed to the end callback, use the appropriate constructor
 *     accordingly.
 *   needs_inputs - whether the callbacks need the inputs passed from the
 *     observed function/range; NOTE: passing the inputs incurs an additional
 *     overhead;
 *   sampling_probability - if not 1.0, then the callback is probabilistically
 *     sampled to run; NOTE: start and end callbacks always run as a pair and
 *     are sampled together;
 *   scopes - types of scopes to execute the callbacks on (see RecordScope);
 *     passing an empty set means the callbacks will be executed for all
 *     possible scope types;
 *   should_run - optional function that returns whether this callback should
 *     run; overrides the effect of setting sampling_probability
 */
class TORCH_API RecordFunctionCallback {
 public:
  using StartCallback =
      std::unique_ptr<ObserverContext> (*)(const RecordFunction&);
  using EndCallback = void (*)(const RecordFunction&, ObserverContext*);

  // This interface supports observers that require passing an
  // ObserverContext between start and end callbacks.
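  //
  // Example (illustrative sketch, not part of this API): a callback pair that
  // times each observed range via a custom ObserverContext. The names
  // `TimingCtx`, `onFunctionEnter` and `onFunctionExit` are hypothetical.
  //
  //   struct TimingCtx : public at::ObserverContext {
  //     std::chrono::steady_clock::time_point start;
  //   };
  //
  //   std::unique_ptr<at::ObserverContext> onFunctionEnter(
  //       const at::RecordFunction& fn) {
  //     auto ctx = std::make_unique<TimingCtx>();
  //     ctx->start = std::chrono::steady_clock::now();
  //     return ctx;
  //   }
  //
  //   void onFunctionExit(
  //       const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
  //     auto* ctx = static_cast<TimingCtx*>(ctx_ptr);
  //     auto elapsed = std::chrono::steady_clock::now() - ctx->start;
  //     // e.g. aggregate `elapsed` keyed by fn.name()
  //   }
  //
  //   at::RecordFunctionCallback cb(&onFunctionEnter, &onFunctionExit);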
  explicit RecordFunctionCallback(
      StartCallback start,
      EndCallback end = nullptr)
      : start_(start), end_(end) {
    scopes_.fill(true);
  }

  RecordFunctionCallback& needsInputs(bool needs_inputs) {
    needs_inputs_ = needs_inputs;
    return *this;
  }

  RecordFunctionCallback& needsOutputs(bool needs_outputs) {
    needs_outputs_ = needs_outputs;
    return *this;
  }

  RecordFunctionCallback& needsIds(bool needs_ids) {
    needs_ids_ = needs_ids;
    return *this;
  }

  RecordFunctionCallback& samplingProb(double sampling_prob) {
    TORCH_CHECK(
        sampling_prob >= 0.0 && sampling_prob <= 1.0,
        "Invalid sampling probability");
    sampling_prob_ = sampling_prob;
    return *this;
  }

  RecordFunctionCallback& scopes(
      const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
    if (!scopes.empty()) {
      scopes_.fill(false);
      for (auto sc : scopes) {
        scopes_[static_cast<size_t>(sc)] = true;
      }
    } else {
      scopes_.fill(true);
    }
    return *this;
  }

  bool needsInputs() const {
    return needs_inputs_;
  }

  bool needsOutputs() const {
    return needs_outputs_;
  }

  bool needsIds() const {
    return needs_ids_;
  }

  double samplingProb() const {
    return sampling_prob_;
  }

  bool checkScope(RecordScope sc) const {
    return scopes_[(size_t)sc];
  }

  StartCallback start() const {
    return start_;
  }

  EndCallback end() const {
    return end_;
  }

 private:
  StartCallback start_;
  EndCallback end_;
  double sampling_prob_ = 1.0;
  std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
  bool needs_inputs_ = false;
  bool needs_outputs_ = false;
  bool needs_ids_ = false;
};

// Notes:
//  - two types of callbacks are provided: thread local and global
//     - thread local callbacks are added/removed only for the given thread
//       and are stored locally for each thread and separately from the list
//       of the global callbacks
//     - global callbacks are stored in a single per process list and are
//       invoked by every RecordFunction, in addition to the thread local
//       callbacks specific to the given thread
//  - we allow the added callbacks to be sampled, by specifying a sampling
//    probability for each callback pair; if the start callback is
//    not picked to run, the corresponding end callback won't be called
//  - a typical use case for the global callbacks is passive monitoring
//    in the background (e.g. fleet-wide monitoring), without focusing on
//    the specific piece of code
//  - in contrast, thread local callbacks are enabled locally, on demand,
//    for the specific piece of code (range) and are not sampled
//  - a typical use case for thread local callbacks is profiler and code
//    execution tracer
//  - note, thread local callbacks are automatically propagated with
//    ThreadLocalState across JIT continuations and async tasks (at::launch)
//  - a registration sketch is given below

typedef uint64_t CallbackHandle;

constexpr CallbackHandle INVALID_CALLBACK_HANDLE{0};

// It is unnecessary to use atomic operations for enabling
// thread-local function callbacks. Moreover, it prevents saving to
// ThreadLocalState because std::atomic is non-copyable.
struct RecordFunctionCallbacksEntry {
  RecordFunctionCallbacksEntry(RecordFunctionCallback&& cb, CallbackHandle h)
      : callback_(cb), handle_(h) {}

  RecordFunctionCallback callback_;
  bool enabled_{true};
  CallbackHandle handle_;
};

// Holds pairs (callbacks, unique_id)
using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;
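// Example (illustrative sketch): registering a sampled global callback and a
// thread-local callback. `onFunctionEnter`/`onFunctionExit` are hypothetical
// free functions matching StartCallback/EndCallback (see the sketch above);
// the registration functions (addGlobalCallback, addThreadLocalCallback,
// removeCallback) are declared further below in this header.
//
//   // Passive, fleet-wide style monitoring: sample 1% of USER_SCOPE ranges.
//   auto global_handle = at::addGlobalCallback(
//       at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
//           .samplingProb(0.01)
//           .scopes({at::RecordScope::USER_SCOPE}));
//
//   // On-demand, per-thread observation with inputs, e.g. for a profiler.
//   auto tls_handle = at::addThreadLocalCallback(
//       at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
//           .needsInputs(true));
//
//   // ... observed work ...
//
//   at::removeCallback(tls_handle);
//   at::removeCallback(global_handle);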
// Generated by the callback managers to determine which functions to run.
struct StepCallbacks {
  StepCallbacks() = default;
  StepCallbacks(uint64_t thread_id, RecordScope scope)
      : thread_id_{thread_id}, scope_{scope} {}

  bool empty() const {
    return callbacks_.empty();
  }

  struct StartEndPair {
    RecordFunctionCallback::StartCallback start_;
    RecordFunctionCallback::EndCallback end_;
  };

  using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>;

  StartEndPairs callbacks_;
  uint64_t thread_id_{0};
  RecordScope scope_{RecordScope::FUNCTION};
  bool needs_inputs_{false};
  bool needs_outputs_{false};
  bool needs_ids_{false};
};

struct TORCH_API RecordFunction {
  // Default constructor is used with before function called afterwards:
  //  scope - record scope that this function tracks
  explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION);
  explicit RecordFunction(StepCallbacks&& step_callbacks);

  template <typename F>
  void before(
      F fn,
      c10::ArrayRef<const c10::IValue> args,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    inputs_ = args;
    before(fn, current_sequence_nr);
  }

  template <typename F>
  void before(
      F fn,
      const std::vector<c10::IValue>* args,
      int64_t current_sequence_nr = -1) {
    before(
        std::move(fn),
        c10::ArrayRef<const c10::IValue>(args->data(), args->size()),
        current_sequence_nr);
  }

  // Destructor calls end callbacks
  virtual ~RecordFunction();

  RecordFunction(const RecordFunction&) = delete;
  RecordFunction& operator=(const RecordFunction&) = delete;

  const char* name() const;

  int64_t seqNr() const {
    return sequence_nr_;
  }

  c10::ArrayRef<const c10::IValue> inputs() const {
#ifndef NDEBUG
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        inputs_valid_, "Called inputs() outside RecordFunction start callback");
#endif
    return inputs_;
  }

  const std::vector<c10::IValue>& outputs() const {
    return outputs_;
  }

  void setOutputs(std::vector<c10::IValue>&& outputs) {
    outputs_ = std::move(outputs);
  }

  void setOutputs(c10::ArrayRef<c10::IValue> outputs) {
    outputs_ = outputs.vec();
  }

  size_t num_inputs() const;
  size_t num_outputs() const;

  // Retrieves the thread_id that this RecordFunction ran start callbacks with.
  // Useful for writing thread safe end callbacks that may be potentially
  // executed in a different thread (async ops)
  uint64_t threadId() const {
    return step_callbacks_.thread_id_;
  }

  // For backward functions - thread id of the corresponding forward function,
  // or zero otherwise;
  // used alongside with sequence number to correlate backward functions with
  // the forward ones
  uint64_t forwardThreadId() const {
    return fwd_thread_id_;
  }

  void setForwardThreadId(uint64_t thread_id) {
    fwd_thread_id_ = thread_id;
  }

  RecordScope scope() const {
    return step_callbacks_.scope_;
  }

  // Returns logical thread_id for the current thread
  static uint64_t currentThreadId();

  // Internal functions, do not use directly;
  // used in python's context manager

  // before functions initialize RecordFunction members and call
  // start callbacks
  using schema_ref_t = std::reference_wrapper<const c10::FunctionSchema>;
  void before(const char* name, int64_t sequence_nr = -1);
  void before(std::string name, int64_t sequence_nr = -1);
  void before(schema_ref_t schema, int64_t sequence_nr = -1);

  // Sets node ID for distributed profiling
  static void setDefaultNodeId(int64_t defaultNodeId);
  // Gets node ID for distributed profiling
  static int64_t getDefaultNodeId();

  // Calls end callbacks. After end(), accessors will no longer provide useful
  // results.
  void end();
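  // Example (illustrative sketch of the manual lifecycle; most call sites
  // should prefer the RECORD_FUNCTION* macros defined later in this header):
  //
  //   at::RecordFunction rf(at::RecordScope::USER_SCOPE);
  //   if (rf.isActive()) {
  //     rf.before("my_range"); // runs start callbacks
  //   }
  //   // ... observed work ...
  //   // end callbacks run in rf.end(), or in the destructor if end() was
  //   // not called explicitly.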
  // Internal-only, used only to force async event for distributed events
  // profiling.
  void _setAsync();

  // Returns whether this RecordFunction corresponds to an async event or not.
  bool isAsync() const;

  // Internal-only, used to denote out variant used for Static Runtime
  // execution
  void _setStaticRuntimeOutVariant();
  bool isStaticRuntimeOutVariant() const;

  RecordFunctionHandle handle() const {
    return handle_;
  }

  c10::optional<c10::OperatorName> operator_name() const;

  // This method returns a copy of the FunctionSchema and can be expensive.
  c10::optional<c10::FunctionSchema> operator_schema() const;

  void setHandle(RecordFunctionHandle handle) {
    handle_ = handle;
  }

  // Whether this RecordFunction runs any callbacks.
  bool isActive() const {
    return !step_callbacks_.empty();
  }

  bool needsInputs() const {
    return step_callbacks_.needs_inputs_;
  }

  bool needsOutputs() const {
    return step_callbacks_.needs_outputs_;
  }

  int64_t debugHandle() const {
    return debug_handle_;
  }

  void setDebugHandle(int64_t debug_handle) {
    debug_handle_ = debug_handle;
  }

  void invalidateInputs() {
#ifndef NDEBUG
    inputs_valid_ = false;
#endif
  }

 private:
  void runStartCallbacks();

  StepCallbacks step_callbacks_;

  // In cases when RecordFunction might be active but we chose not to
  // use the observers (e.g. operator is not observed), this boolean
  // flag is used to check whether the start callbacks were called
  bool called_start_callbacks_ = false;

#ifndef NDEBUG
  bool inputs_valid_ = false;
#endif

  // Stores various ObserverContext objects with event metadata for callbacks.
  ObserverContextList ctx_;

  c10::variant<std::string, schema_ref_t> fn_;

  int64_t sequence_nr_ = -1;
  c10::ArrayRef<const c10::IValue> inputs_;
  std::vector<c10::IValue> outputs_;

  // For backward functions - thread id of the forward function
  uint64_t fwd_thread_id_ = 0;

  // Unique id for this RecordFunction, used in callbacks to track start
  // and end of ranges
  RecordFunctionHandle handle_{0};

  // Whether this record_function corresponds to an async event or not. Async
  // events can complete in different threads or follow a future-like pattern
  // of use.
  bool is_async_{false};

  // Debug handles are used for lazy annotation of module hierarchy
  // and callstack.
  // This is specifically useful for mobile runtime, where generated
  // debug handles can be lazily symbolicated using debug information
  int64_t debug_handle_{-1};

  // Whether this RecordFunction is used for an out variant run with
  // Static Runtime
  bool is_static_runtime_out_variant_{false};
};

TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(
    RecordScope scope);
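// Example (illustrative sketch): a caller can skip constructing a
// RecordFunction entirely when no callbacks are registered for a scope, and
// otherwise hand the precomputed StepCallbacks to the guard:
//
//   auto step_callbacks =
//       at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
//   if (step_callbacks.has_value()) {
//     at::RecordFunction guard(std::move(*step_callbacks));
//     guard.before("my_op");
//     // ... run the op; end callbacks fire when `guard` goes out of scope ...
//   }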
namespace detail {

template <typename Inputs, typename F, typename... Args>
void record_function_with_scope(
    RecordFunction& guard,
    F fn,
    const Inputs& inputs,
    Args&&... args) {
  if (guard.needsInputs()) {
    guard.before(
        fn,
        c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
        std::forward<Args>(args)...);
  } else {
    guard.before(fn, std::forward<Args>(args)...);
  }
}

template <typename Inputs, typename F, typename... Args>
void record_function_with_scope_and_debug_handle(
    RecordFunction& guard,
    F fn,
    int64_t debug_handle,
    const Inputs& inputs,
    Args&&... args) {
  guard.setDebugHandle(debug_handle);
  if (guard.needsInputs()) {
    guard.before(
        fn,
        c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
        std::forward<Args>(args)...);
  } else {
    guard.before(fn, std::forward<Args>(args)...);
  }
}

template <typename F, typename... Args>
void record_function_with_scope(
    RecordFunction& guard,
    F fn,
    c10::ArrayRef<c10::IValue> inputs,
    Args&&... args) {
  return record_function_with_scope<c10::ArrayRef<c10::IValue>, F, Args...>(
      guard, std::move(fn), inputs, std::forward<Args>(args)...);
}

template <typename F, typename... Args>
void record_function_with_scope_and_debug_handle(
    RecordFunction& guard,
    F fn,
    int64_t debug_handle,
    c10::ArrayRef<c10::IValue> inputs,
    Args&&... args) {
  return record_function_with_scope_and_debug_handle<
      c10::ArrayRef<c10::IValue>,
      F,
      Args...>(
      guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...);
}

} // namespace detail

// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
  at::RecordFunction guard(scope);                         \
  if (guard.isActive()) {                                  \
    ::at::detail::record_function_with_scope(              \
        guard, fn, inputs, ##__VA_ARGS__);                 \
  }

#define RECORD_FUNCTION(fn, inputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE(            \
      at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__)

#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

// Custom user scopes in C++; similar to Python's
// 'with record_function("..."):'
#define RECORD_USER_SCOPE(fn) \
  RECORD_FUNCTION_WITH_SCOPE( \
      at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})

// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs)

// Helper macro to pass in debug handle that is used to
// post process events
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(             \
    scope, fn, debug_handle, inputs, ...)                      \
  at::RecordFunction guard(scope);                             \
  if (guard.isActive()) {                                      \
    ::at::detail::record_function_with_scope_and_debug_handle( \
        guard, fn, debug_handle, inputs, ##__VA_ARGS__);       \
  }

// Helper macros to record LITE INTERPRETER scope events with debug handles
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \
    fn, debug_handle, inputs)                           \
  RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(            \
      at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs)
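// Example (illustrative sketch): instrumenting C++ code with the macros
// defined above. `my_op` and `run_expensive_step` are placeholder names.
//
//   void my_op(const at::Tensor& a, const at::Tensor& b) {
//     RECORD_FUNCTION("my_op", std::vector<c10::IValue>({a, b}));
//     // ... kernel body; end callbacks fire when the scope exits ...
//   }
//
//   void run_expensive_step() {
//     RECORD_USER_SCOPE("expensive_step");
//     // ... code observed under a user scope ...
//   }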
/**
 * addThreadLocalCallback adds a thread local callback to run with
 * RecordFunction, returns handle to use with removeCallback
 */
TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb);

/**
 * hasThreadLocalCallbacks returns whether there are callbacks registered
 * with addThreadLocalCallback
 */
TORCH_API bool hasThreadLocalCallbacks();

/**
 * clearThreadLocalCallbacks removes all thread local callbacks
 */
TORCH_API void clearThreadLocalCallbacks();

/**
 * addGlobalCallback adds a global callback to run with RecordFunction:
 *
 * WARNING: not thread safe; should typically be called
 * only during the program initialization
 */
TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb);

/**
 * removeCallback removes a callback given the handle returned by
 * addThreadLocalCallback or addGlobalCallback;
 *
 * WARNING: not thread safe; no other code can run simultaneously
 */
TORCH_API void removeCallback(CallbackHandle handle);

/**
 * Prevent the given callback from executing. If handle is invalid,
 * does nothing.
 */
TORCH_API void disableCallback(CallbackHandle handle);

/**
 * Allow the given callback, previously disabled with disableCallback, to
 * execute again. If handle is invalid, does nothing.
 */
TORCH_API void reenableCallback(CallbackHandle handle);

/**
 * hasGlobalCallbacks returns whether there are global callbacks
 * registered with addGlobalCallback
 */
TORCH_API bool hasGlobalCallbacks();

/**
 * clearGlobalCallbacks removes all global callbacks
 */
TORCH_API void clearGlobalCallbacks();

// for both thread local and global callbacks
TORCH_API bool hasCallbacks();
TORCH_API void clearCallbacks();

/**
 * enableRecordFunction enables RecordFunction thread locally
 */
TORCH_API void enableRecordFunction(bool enable = true);

/**
 * isRecordFunctionEnabled returns whether RecordFunction
 * is enabled thread locally
 */
TORCH_API bool isRecordFunctionEnabled();

class TORCH_API RecordFunctionGuard {
 public:
  explicit RecordFunctionGuard(bool is_enabled = true)
      : prev_value_(isRecordFunctionEnabled()) {
    enableRecordFunction(is_enabled);
  }

  virtual ~RecordFunctionGuard() {
    enableRecordFunction(prev_value_);
  }

 private:
  bool prev_value_ = false;
};

class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
 public:
  DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
  virtual ~DisableRecordFunctionGuard() {}
};

struct TORCH_API RecordFunctionTLS {
  // Thread local vector of callbacks, holds pairs (callbacks, unique_id);
  // must be sorted in increasing handles order
  RecordFunctionCallbacks sorted_tls_callbacks_;

  bool tls_record_function_enabled_ = true;
};

TORCH_API const RecordFunctionTLS& get_record_function_tls_();

TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);

TORCH_API void set_record_function_seed_for_testing(uint32_t seed);

} // namespace at
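// Example (illustrative sketch): temporarily suppressing observers around a
// region whose internal ops should not be reported, e.g. bookkeeping code
// inside a profiled section:
//
//   {
//     at::DisableRecordFunctionGuard no_observers;
//     // RecordFunction is disabled for this thread within this block;
//     // callbacks are not invoked for ops executed here.
//   }
//   // the previous thread-local enablement is restored on scope exit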