#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit { void packGradient(const Gradient& gradient, Node* dnode); bool needsGradient(const std::shared_ptr& graph); void runOptimization( std::shared_ptr& graph, bool unroll_non_constant_loops = true, bool const_prop_user_classes = true); void runNondiffOptimization( std::shared_ptr& graph, bool strict_fuser_check = false); void debugSetAutodiffSubgraphInlining(bool state); bool TORCH_API getAutodiffSubgraphInlining(); void debugSetFusionGroupInlining(bool state); bool getFusionGroupInlining(); // Tunable parameters for deciding when to create/keep subgraphs of // differentiable code const size_t autodiffSubgraphNodeThreshold = 2; const size_t autodiffSubgraphInlineThreshold = 5; // a Graph can be created via tracing, or via a language-based frontend // GraphExecutor runs it. It can run the same graph on many different sizes // and different requires_grad states, and handles specializations for each // situation. GraphExecutor is completely unaware of tracing or module // parameters to keep the tracing concerns separated. struct GraphExecutorImplBase { static std::shared_ptr prepareGraph( const std::shared_ptr& graph) { auto copy = graph->copy(); EraseShapeInformation(copy); return copy; } GraphExecutorImplBase( const std::shared_ptr& graph, std::string function_name) : graph(prepareGraph(graph)), function_name_(std::move(function_name)), num_inputs(this->graph->inputs().size()), num_outputs(this->graph->outputs().size()) {} // entry point where execution begins void run(Stack& stack); c10::intrusive_ptr runAsync( Stack& stack, TaskLauncher taskLauncher = at::launch); virtual const ExecutionPlan& getPlanFor( Stack& stack, std::optional remaining_bailout_depth = std::nullopt) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; virtual bool isOptimized() const { return false; } protected: friend struct GraphExecutor; // The unoptimized starting graph. This field is effectively const, but we // can't make it so because Graph::copy() is not const (and making it const is // not that easy at this point). // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::shared_ptr graph; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::string function_name_; // If false, we'll run the graph as we get it, without any optimizations. // Useful for debugging. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const size_t num_inputs; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const size_t num_outputs; // GraphExecutors can be accessed from multiple threads, so this thread needs // to be held every time we access the fallback or plan_cache. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::mutex compile_mutex; }; } // namespace torch::jit