executor.h

xiaoxiao2021-02-28  77

#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_ #define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_ #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/session_state.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { class StepStatsCollector; // Executor runs a graph computation. // 执行器运行图形计算。 // Example: // Graph* graph = ...; // ... construct graph ... // Executor* executor; // TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor)); // Rendezvous* rendezvous = NewNaiveRendezvous(); // TF_CHECK_OK(rendezvous->Send("input", some_input_tensor)); // TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr})); // TF_CHECK_OK(rendezvous->Recv("output", &output_tensor)); // ... ... // // Multiple threads can call Executor::Run concurrently. class Executor { public: virtual ~Executor() {} // RunAsync() executes the graph computation. "done" is run when the graph computation completes. // If any error happens during the computation, "done" is run and the error is passed to "done". // RunAsync() 执行图形计算。 "done" 在图形计算完成时运行。 // 如果在计算过程中出现任何错误,运行"done",并将错误传递给"done"。 // // RunAsync() is given a few arguments in Args. The caller must ensure objects passed in Args // (rendezvous, stats_collector, etc.) are alive at least until done is invoked. // All pointers to the argument objects can be nullptr. //  在 Args 中给出了 RunAsync() 的一些参数 。 // 调用者必须确保在 Args(rendezvous,stats_collector 等)中传递的对象至少在调用完成之前是活着的。 // 参数对象的所有指针都可以为 nullptr。 // // "step_id" is a process-wide unique identifier for the step being run. Executors on different // devices may receive the same step_id in the case that a step runs Ops on more than one device. 
// The step_id is used for tracking resource usage of a given step. // "step_id" 是正在运行的步骤的流程范围的唯一标识符。在步骤在多个设备上运行 Ops 的情况下, // 不同设备上的执行程序可能会收到相同的 step_id。 step_id 用于跟踪给定步骤的资源使用情况。 // // RunAsync() uses the given "rendezvous", if not null, as the mechanism to communicate // inputs and outputs of the underlying graph computation. // 如果不为空,RunAsync() 使用给定的 "rendezvous" 作为通信底层图形计算的输入和输出的机制。 // // RunAsync() calls "stats_collector", if not null, to keep track of stats. // This allows us to collect statistics and traces on demand. // 如果不为空,RunAsync() 调用 "stats_collector" 来跟踪统计信息。 // 这使我们能够根据需要收集统计数据和痕迹。 // // RunAsync() is provided a "call_frame", if the executor is used for executing a function, // is used to pass arguments and return values between the caller and the callee. // 如果 executor 用于执行一个函数,RunAsync() 被提供了一个 "call_frame" // 用于传递参数并在调用者和被调用者之间返回值。 // // RunAsync() uses "cancellation_manager", if not nullptr, to register callbacks that // should be called if the graph computation is cancelled. // Note that the callbacks merely unblock any long-running computation, // and a cancelled step will terminate by returning/calling the DoneCallback as usual. // 如果不为空,RunAsync() 使用 "cancellation_manager"注册回调,当图形计算被取消应该调用。 // 请注意,回调只是解除阻止任何长时间运行的计算,一般取消的步骤将通过返回/调用 DoneCallback 来终止。 // // RunAsync() dispatches closures to "runner". // Typically, "runner" is backed up by a bounded threadpool. // RunAsync() 将闭包(closures)分派到 "runner"。通常,"runner" 由有界线程池备份。 struct Args { int64 step_id = 0; Rendezvous* rendezvous = nullptr; StepStatsCollector* stats_collector = nullptr; FunctionCallFrame* call_frame = nullptr; CancellationManager* cancellation_manager = nullptr; SessionState* session_state = nullptr; TensorStore* tensor_store = nullptr; ScopedStepContainer* step_container = nullptr; // If true, calls Sync() on the device. 
bool sync_on_finish = false; typedef std::function<void()> Closure; typedef std::function<void(Closure)> Runner; Runner runner = nullptr; // A callback that is invoked each time a node has finished executing. typedef std::function<Status(const string& node_name, const int output_slot, const Tensor* tensor, const bool is_ref, OpKernelContext* ctx)> NodeOutputsCallback; NodeOutputsCallback node_outputs_cb = nullptr; }; typedef std::function<void(const Status&)> DoneCallback; virtual void RunAsync(const Args& args, DoneCallback done) = 0; // Synchronous wrapper for RunAsync(). Status Run(const Args& args) { Status ret; Notification n; RunAsync(args, [&ret, &n](const Status& s) { ret = s; n.Notify(); }); n.WaitForNotification(); return ret; } }; // Creates an Executor that computes the given "graph". // // If successful, returns the constructed executor in "*executor". The // caller keeps the ownership of "device". The returned executor takes // the ownership of "graph". Otherwise, returns an error status. // // "params" provides a set of context for the executor. We expect that // different context would provide different implementations. struct LocalExecutorParams { Device* device; // The library runtime support. FunctionLibraryRuntime* function_library = nullptr; // create_kernel returns an instance of op kernel based on NodeDef. // delete_kernel is called for every kernel used by the executor // when the executor is deleted. std::function<Status(const NodeDef&, OpKernel**)> create_kernel; std::function<void(OpKernel*)> delete_kernel; Executor::Args::NodeOutputsCallback node_outputs_cb; }; ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph, Executor** executor); // A class to help run multiple executors in parallel and wait until all of them are complete. // // ExecutorBarrier deletes itself after the function returned by Get() is called. 
class ExecutorBarrier { public: typedef std::function<void(const Status&)> StatusCallback; // Create an ExecutorBarrier for 'num' different executors. // // 'r' is the shared Rendezvous object that is used to communicate state. // If any of the executors experiences an error, the rendezvous object will be aborted exactly once. // // 'done' is called after the last executor completes, and ExecutorBarrier is deleted. // ExecutorBarrier(int num, Rendezvous* r, StatusCallback done) : rendez_(r), done_cb_(done), pending_(num) {} ~ExecutorBarrier() {} // Returns a closure that Executors must call when they are done computing, // passing the status of their execution as an argument. StatusCallback Get() { return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1); } private: Rendezvous* rendez_ = nullptr; StatusCallback done_cb_ = nullptr; mutable mutex mu_; int pending_ GUARDED_BY(mu_) = 0; Status status_ GUARDED_BY(mu_); void WhenDone(const Status& s) { bool error = false; Rendezvous* error_rendez = nullptr; StatusCallback done = nullptr; Status status; { mutex_lock l(mu_); // If we are the first error encountered, mark the status appropriately and later // trigger an abort of the Rendezvous object by this thread only. if (status_.ok() && !s.ok()) { error = true; error_rendez = rendez_; error_rendez->Ref(); status_ = s; } // If this is the last call to WhenDone, call the final callback below. if (--pending_ == 0) { CHECK(done_cb_ != nullptr); done = done_cb_; done_cb_ = nullptr; } status = status_; } if (error) { error_rendez->StartAbort(status); error_rendez->Unref(); } if (done != nullptr) { delete this; done(status); } } TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier); }; // A few helpers to facilitate create/delete kernels. // Creates a kernel based on "ndef" on device "device". The kernel can access the functions // in the "flib". The caller takes ownership of returned "*kernel". 
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, const NodeDef& ndef, int graph_def_version, OpKernel** kernel); // Deletes "kernel" returned by CreateKernel. void DeleteNonCachedKernel(OpKernel* kernel); } // end namespace tensorflow #endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_

#include "tensorflow/core/common_runtime/executor.h"

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}

// Sets the timeline_label field of *node_stats, using data from *node.
// Returns true iff the node is a transfer node (Send or Recv).
//
// The label has the form "<memory>name = op(args)", where <memory> lists
// every allocator that handed out at least 0.1MB, e.g. "[cpu 1.5MB 2.0MB] ".
// TODO(tucker): merge with the DetailText function in session.cc in a
// common location.
bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
  constexpr double kBytesPerMB = 1048576.0;
  bool is_transfer_node = false;

  string memory;
  for (const auto& all : node_stats->memory()) {
    const int64 tot = all.total_bytes();
    if (tot >= 0.1 * kBytesPerMB) {
      const int64 peak = all.peak_bytes();
      if (peak > 0) {
        strings::StrAppend(&memory, "[", all.allocator_name(),
                           strings::Printf(" %.1fMB %.1fMB] ",
                                           tot / kBytesPerMB,
                                           peak / kBytesPerMB));
      } else {
        strings::StrAppend(
            &memory, "[", all.allocator_name(),
            strings::Printf(" %.1fMB] ", tot / kBytesPerMB));
      }
    }
  }

  const NodeDef& def = node->def();
  string text;
  if (IsSend(node)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
    string recv_device;
    TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
    // Close the argument list so transfer labels match "name = op(...)"
    // like every other node's label.
    text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
                           tensor_name, " @", recv_device, ")");
    is_transfer_node = true;
  } else if (IsRecv(node)) {
    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
    string send_device;
    TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
    text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
                           tensor_name, " @", send_device, ")");
    is_transfer_node = true;
  } else {
    text = strings::StrCat(
        memory, def.name(), " = ", def.op(), "(",
        str_util::Join(
            std::vector<StringPiece>(def.input().begin(), def.input().end()),
            ", "),
        ")");
  }
  node_stats->set_timeline_label(text);
  return is_transfer_node;
}

// Helper routines for collecting step stats.
namespace nodestats { inline int64 NowInUsec() { return Env::Default()->NowMicros(); } void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); } void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); } void SetOpStart(NodeExecStats* nt) { DCHECK_NE(nt->all_start_micros(), 0); nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros()); } void SetOpEnd(NodeExecStats* nt) { DCHECK_NE(nt->all_start_micros(), 0); nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros()); } void SetAllEnd(NodeExecStats* nt) { DCHECK_NE(nt->all_start_micros(), 0); nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros()); } void SetOutput(NodeExecStats* nt, int slot, const Tensor* v) { DCHECK(v); NodeOutput* no = nt->add_output(); no->set_slot(slot); v->FillDescription(no->mutable_tensor_description()); } void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) { for (const auto& allocator_pair : ctx->wrapped_allocators()) { AllocatorMemoryUsed* memory = nt->add_memory(); // retrieving the sizes from the wrapped allocator removes the executor's reference to it, // so allocator_pair.second must not be dereferenced again after this statement auto sizes = allocator_pair.second->GetSizesAndUnRef(); memory->set_allocator_name(allocator_pair.first->Name()); memory->set_total_bytes(std::get<0>(sizes)); if (allocator_pair.first->TracksAllocationSizes()) { memory->set_peak_bytes(std::get<1>(sizes)); memory->set_live_bytes(std::get<2>(sizes)); } } auto* ms = nt->mutable_memory_stats(); ms->set_host_temp_memory_size(ctx->host_temp_memory_size()); ms->set_device_temp_memory_size(ctx->device_temp_memory_size()); for (const auto& alloc_id : ctx->host_persistent_alloc_ids()) { ms->mutable_host_persistent_tensor_alloc_ids()->Add(alloc_id); } for (const auto& alloc_id : ctx->device_persistent_alloc_ids()) { ms->mutable_device_persistent_tensor_alloc_ids()->Add(alloc_id); } 
ms->set_host_persistent_memory_size(ctx->host_persistent_memory_allocated()); ms->set_device_persistent_memory_size( ctx->device_persistent_memory_allocated()); } void SetReferencedTensors(NodeExecStats* nt, const TensorReferenceVector& tensors) { // be careful not to increment the reference count on any tensor // while recording the information for (size_t i = 0; i < tensors.size(); ++i) { AllocationDescription* description = nt->add_referenced_tensor(); tensors.at(i).FillDescription(description); } } } // namespace nodestats class ExecutorImpl; class GraphView; struct EdgeInfo { int dst_id; int output_slot : 31; // true if this is the last info for output_slot in the EdgeInfo list. bool is_last : 1; int input_slot; }; struct NodeItem { NodeItem() {} // A graph node. const Node* node = nullptr; // The kernel for this node. OpKernel* kernel = nullptr; bool kernel_is_expensive : 1; // True iff kernel->IsExpensive() bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr bool is_merge : 1; // True iff IsMerge(node) bool is_enter : 1; // True iff IsEnter(node) bool is_exit : 1; // True iff IsExit(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_sink : 1; // True iff IsSink(node) // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) bool is_enter_exit_or_next_iter : 1; // Cached values of node->num_inputs() and node->num_outputs(), to avoid levels of indirection. int num_inputs; int num_outputs; // ExecutorImpl::tensors_[input_start] is the 1st positional input for this node. int input_start = 0; // Number of output edges. int num_output_edges; PendingCounts::Handle pending_id; const EdgeInfo* output_edge_list() const { return output_edge_base(); } // ith output edge. 
const EdgeInfo& output_edge(int i) const { DCHECK_GE(i, 0); DCHECK_LT(i, num_output_edges); return output_edge_base()[i]; } DataType input_type(int i) const { DCHECK_LT(i, num_inputs); return static_cast<DataType>(input_type_base()[i]); } DataType output_type(int i) const { DCHECK_LT(i, num_outputs); return static_cast<DataType>(output_type_base()[i]); } // Return array of per-output allocator attributes. const AllocatorAttributes* output_attrs() const { return output_attr_base(); } private: friend class GraphView; // Variable length section starts immediately after *this // (uint8 is enough for DataType). // EdgeInfo out_edges[num_out_edges]; // AllocatorAttributes output_attr[num_outputs]; // uint8 input_type[num_inputs]; // uint8 output_type[num_outputs]; // Return pointer to variable length section. char* var() const { return const_cast<char*>(reinterpret_cast<const char*>(this) + sizeof(NodeItem)); } EdgeInfo* output_edge_base() const { return reinterpret_cast<EdgeInfo*>(var()); } AllocatorAttributes* output_attr_base() const { return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) * num_output_edges); } uint8* input_type_base() const { return reinterpret_cast<uint8*>(var() + sizeof(EdgeInfo) * num_output_edges + sizeof(AllocatorAttributes) * num_outputs); } uint8* output_type_base() const { return reinterpret_cast<uint8*>( var() + sizeof(EdgeInfo) * num_output_edges + sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs); } TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); }; typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec; typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; // Immutable view of a Graph organized for efficient execution. 
class GraphView { public: GraphView() : space_(nullptr) {} ~GraphView(); void Initialize(const Graph* g); Status SetAllocAttrs(const Graph* g, const Device* device); NodeItem* node(int id) const { DCHECK_GE(id, 0); DCHECK_LT(id, num_nodes_); uint32 offset = node_offsets_[id]; return ((offset == kuint32max) ? nullptr : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id])); } private: char* InitializeNode(char* ptr, const Node* n); size_t NodeItemBytes(const Node* n); int32 num_nodes_ = 0; uint32* node_offsets_ = nullptr; // array of size "graph_.num_node_ids()" // node_offsets_[id] holds the byte offset for node w/ "id" in space_ char* space_; // NodeItem objects are allocated here TF_DISALLOW_COPY_AND_ASSIGN(GraphView); }; class ExecutorImpl : public Executor { public: ExecutorImpl(const LocalExecutorParams& p, const Graph* g) : params_(p), graph_(g), gview_() { CHECK(p.create_kernel != nullptr); CHECK(p.delete_kernel != nullptr); } ~ExecutorImpl() override { for (int i = 0; i < graph_->num_node_ids(); i++) { NodeItem* item = gview_.node(i); if (item != nullptr) { params_.delete_kernel(item->kernel); } } for (auto fiter : frame_info_) { delete fiter.second; } delete graph_; } Status Initialize(); // Process all Nodes in the current graph, attempting to infer the // memory allocation attributes to be used wherever they may allocate // a tensor buffer. Status SetAllocAttrs(); void RunAsync(const Args& args, DoneCallback done) override; private: friend class ExecutorState; struct ControlFlowInfo { gtl::FlatSet<string, HashStr> unique_frame_names; std::vector<string> frame_names; }; struct FrameInfo { FrameInfo() : input_count(0), total_inputs(0), pending_counts(nullptr), nodes(nullptr) {} // The total number of inputs to a frame. int input_count; // The total number of input tensors of a frame. // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. 
int total_inputs; // Used to determine the next place to allocate space in the // pending_counts data structure we'll eventually construct // 用于确定下一个位置,以分配我们将最终构建的 pending_counts 数据结构中的空间, PendingCounts::Layout pending_counts_layout; // Each frame has its own PendingCounts only for the nodes in the frame. // 每个 frame 都有自己的 PendingCounts,仅用于 frame 中的节点。 PendingCounts* pending_counts; // Owned // The nodes in a frame. Used only for debugging. // 一个 frame 中的节点。 仅用于调试。 std::vector<const Node*>* nodes; // Owned ~FrameInfo() { delete pending_counts; delete nodes; } }; static Status BuildControlFlowInfo(const Graph* graph, ControlFlowInfo* cf_info); void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); FrameInfo* EnsureFrameInfo(const string& fname) { auto slot = &frame_info_[fname]; if (*slot == nullptr) { *slot = new FrameInfo; } return *slot; } // Owned. LocalExecutorParams params_; const Graph* graph_; GraphView gview_; // A cached value of params_ bool device_record_tensor_accesses_ = false; // Root nodes (with no in edges) that should form the initial ready queue // 形成初始就绪队列的根节点(没有边) std::vector<const Node*> root_nodes_; // Mapping from frame name to static information about the frame. // TODO(yuanbyu): We could cache it along with the graph so to avoid // the overhead of constructing it for each executor instance. gtl::FlatMap<string, FrameInfo*, HashStr> frame_info_; TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl); }; // Infer memory allocation attributes of a node n's output, based on its use node dst. // Note that dst might not be directly connected to n by a single edge, but might be a // downstream consumer of n's output by reference. *attr is updated with any necessary attributes. 
Status InferAllocAttr(const Node* n, const Node* dst, const DeviceNameUtils::ParsedName& local_dev_name, AllocatorAttributes* attr); GraphView::~GraphView() { static_assert(std::is_trivially_destructible<AllocatorAttributes>::value, "Update code if AllocatorAttributes gains a destructor"); static_assert(std::is_trivially_destructible<EdgeInfo>::value, "Update code if EdgeInfo gains a destructor"); for (int i = 0; i < num_nodes_; i++) { NodeItem* n = node(i); if (n != nullptr) { n->NodeItem::~NodeItem(); // Memory for "n" itself is held in space_ & gets cleaned up below } } delete[] node_offsets_; delete[] space_; } size_t GraphView::NodeItemBytes(const Node* n) { const int num_output_edges = n->out_edges().size(); const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); // Compute number of bytes needed for NodeItem and variable length data. // We do not subtract sizeof(var) since num_inputs/num_outputs might both be zero. // 计算 NodeItem 和可变长度数据所需的字节数。   // 我们不会减去 sizeof(var),因为 num_inputs/num_outputs 都可能为零。 const size_t raw_bytes = sizeof(NodeItem) // Fixed + num_output_edges * sizeof(EdgeInfo) // output_edges[...] + num_outputs * sizeof(AllocatorAttributes) // output_attr[...] 
+ num_inputs * sizeof(uint8) // input_type[num_inputs] + num_outputs * sizeof(uint8); // output_type[num_outputs] static constexpr size_t kItemAlignment = sizeof(NodeItem*); static_assert(kItemAlignment % alignof(NodeItem) == 0, "NodeItem must be aligned with kItemAlignment"); static_assert(kItemAlignment % alignof(EdgeInfo) == 0, "EdgeInfo must be aligned with kItemAlignment"); static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0, "AllocatorAttributes must be aligned with kItemAlignment"); static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0, "NodeItem must be aligned with EdgeInfo"); static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0, "NodeItem must be aligned with AllocatorAttributes"); static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0, "EdgeInfo must be aligned with AllocatorAttributes"); const size_t bytes = ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment; return bytes; } char* GraphView::InitializeNode(char* ptr, const Node* n) { const int id = n->id(); CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor const size_t bytes = NodeItemBytes(n); constexpr size_t kItemAlignment = sizeof(NodeItem*); CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0); NodeItem* item = reinterpret_cast<NodeItem*>(ptr); // We store a 32-bit offset relative to the beginning of space_, so that we only need an array // of 32-bit values to map from node id to the NodeItem*, (versus 64 bits on most machines if we // just stored an array of NodeItem* pointers). Casting to int64 is needed on 32bit CPU to // avoid comparing values as "int" vs "size_t" in CHECK_LE. 
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max); const uint32 offset = ptr - space_; node_offsets_[id] = offset; ptr += bytes; const int num_output_edges = n->out_edges().size(); const int num_inputs = n->num_inputs(); const int num_outputs = n->num_outputs(); new (item) NodeItem(); item->num_inputs = num_inputs; item->num_outputs = num_outputs; item->num_output_edges = num_output_edges; // Fill output edges. 填充输出边。 // Keep track of the last EdgeInfo in the EdngeInfo array that references a given output slot. // For all but the last, we need to do a copy of the Tensor when propagating results downstream // in the graph, but for the last one, we can just do a move of the Tensor object to propagate it. gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr); EdgeInfo* dst_edge = item->output_edge_base(); for (auto e : n->out_edges()) { dst_edge->dst_id = e->dst()->id(); CHECK_LE(e->src_output(), ((int32)0x3FFFFFFF)); // Must fit in 31 bits dst_edge->output_slot = e->src_output(); dst_edge->is_last = false; const int output_slot = dst_edge->output_slot; if (output_slot >= 0) { last_indices[output_slot] = dst_edge; } dst_edge->input_slot = e->dst_input(); dst_edge++; } for (EdgeInfo* edge_info : last_indices) { if (edge_info != nullptr) { edge_info->is_last = true; } } AllocatorAttributes* output_attrs = item->output_attr_base(); for (int i = 0; i < num_outputs; i++) { new (&output_attrs[i]) AllocatorAttributes(); } DCHECK_LT(DataType_MAX, 255); // Must fit in uint8 uint8* input_types = item->input_type_base(); for (int i = 0; i < num_inputs; i++) { input_types[i] = static_cast<uint8>(n->input_type(i)); DCHECK_EQ(item->input_type(i), n->input_type(i)); } uint8* output_types = item->output_type_base(); for (int i = 0; i < num_outputs; i++) { output_types[i] = static_cast<uint8>(n->output_type(i)); DCHECK_EQ(item->output_type(i), n->output_type(i)); } return ptr; } void GraphView::Initialize(const Graph* g) { CHECK(node_offsets_ == nullptr); const int 
num_nodes = g->num_node_ids(); num_nodes_ = num_nodes; size_t total_bytes = 0; for (const Node* n : g->nodes()) { total_bytes += NodeItemBytes(n); } node_offsets_ = new uint32[num_nodes]; for (int i = 0; i < num_nodes; i++) { node_offsets_[i] = kuint32max; } space_ = new char[total_bytes]; // NodeItem objects are allocated here char* ptr = space_; for (const Node* n : g->nodes()) { ptr = InitializeNode(ptr, n); } CHECK_EQ(ptr, space_ + total_bytes); } void GetMaxPendingCounts(const Node* n, int* max_pending, int* max_dead_count) { const int num_in_edges = n->in_edges().size(); int initial_count; if (IsMerge(n)) { // merge waits all control inputs so we initialize the pending // count to be the number of control edges. // 合并等待所有控制输入,所以我们将待处理计数初始化为控制边的数量。 int32 num_control_edges = 0; for (const Edge* edge : n->in_edges()) { if (edge->IsControlEdge()) { num_control_edges++; } } // Use bit 0 to indicate if we are waiting for a ready live data input. initial_count = 1 + (num_control_edges << 1); } else { initial_count = num_in_edges; } *max_pending = initial_count; *max_dead_count = num_in_edges; } Status ExecutorImpl::Initialize() { gview_.Initialize(graph_); // Build the information about frames in this subgraph. // 在此子图中构建有关 frames 的信息。 ControlFlowInfo cf_info; TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_, &cf_info)); // Cache this value so we make this virtual function call once, rather // that O(# steps * # nodes per step) times. device_record_tensor_accesses_ = params_.device->RequiresRecordingAccessedTensors(); for (auto& it : cf_info.unique_frame_names) { EnsureFrameInfo(it)->nodes = new std::vector<const Node*>; } // Preprocess every node in the graph to create an instance of op kernel for each node. 
// 预处理图中的每个节点,为每个节点创建一个 op 内核实例。 for (const Node* n : graph_->nodes()) { const int id = n->id(); const string& frame_name = cf_info.frame_names[id]; FrameInfo* frame_info = EnsureFrameInfo(frame_name); // See if this node is a root node, and if so, add to root_nodes_. // 看看这个节点是否是根节点,如果是这样,添加到 root_nodes_。 const int num_in_edges = n->in_edges().size(); if (num_in_edges == 0) { root_nodes_.push_back(n); } NodeItem* item = gview_.node(id); item->node = n; item->input_start = frame_info->total_inputs; frame_info->total_inputs += n->num_inputs(); Status s = params_.create_kernel(n->def(), &item->kernel); if (!s.ok()) { item->kernel = nullptr; s = AttachDef(s, n->def()); LOG(ERROR) << "Executor failed to create kernel. " << s; return s; } CHECK(item->kernel); item->kernel_is_expensive = item->kernel->IsExpensive(); item->kernel_is_async = (item->kernel->AsAsync() != nullptr); item->is_merge = IsMerge(n); item->is_enter = IsEnter(n); item->is_exit = IsExit(n); item->is_control_trigger = IsControlTrigger(n); item->is_sink = IsSink(n); item->is_enter_exit_or_next_iter = (IsEnter(n) || IsExit(n) || IsNextIteration(n)); // Compute the maximum values we'll store for this node in the pending counts data structure, // and allocate a handle in that frame's pending counts data structure that has enough // space to store these maximal count values. // 计算我们将在此等待计数数据结构中为该节点存储的最大值, // 并在该 frame 的待处理计数数据结构中分配一个句柄,该结构具有足够的空间来存储这些最大计数值。 int max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); item->pending_id = frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead); // Initialize static information about the frames in the graph. // 初始化关于图中 frame 的静态信息。 frame_info->nodes->push_back(n); if (IsEnter(n)) { string enter_name; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name)); EnsureFrameInfo(enter_name)->input_count++; } } // Initialize PendingCounts only after item->pending_id is initialized for all nodes. 
// 只有在所有节点都初始化了 item->pending_id 之后才初始化 PendingCounts。 InitializePending(graph_, cf_info); return gview_.SetAllocAttrs(graph_, params_.device); } Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { Status s; DeviceNameUtils::ParsedName local_dev_name = device->parsed_name(); for (const Node* n : g->nodes()) { NodeItem* item = node(n->id()); AllocatorAttributes* attrs = item->output_attr_base(); // Examine the out edges of each node looking for special use // cases that may affect memory allocation attributes. // 检查每个节点的输出边,寻找可能影响内存分配属性的特殊用例。 for (auto e : n->out_edges()) { if (!e->IsControlEdge()) { AllocatorAttributes attr; s = InferAllocAttr(n, e->dst(), local_dev_name, &attr); if (!s.ok()) return s; if (attr.value != 0) { attrs[e->src_output()].Merge(attr); } } } for (int out = 0; out < n->num_outputs(); out++) { const OpKernel* op_kernel = item->kernel; DCHECK_LT(out, op_kernel->output_memory_types().size()); bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; if (on_host) { AllocatorAttributes h; h.set_on_host(on_host); attrs[out].Merge(h); } } } return s; } Status InferAllocAttr(const Node* n, const Node* dst, const DeviceNameUtils::ParsedName& local_dev_name, AllocatorAttributes* attr) { Status s; // Note that it's possible for *n to be a Recv and *dst to be a Send, // so these two cases are not mutually exclusive. // 请注意,*n 可能是 Recv 节点, *dst 可能是 Send 节点,因此这两种情况不是互斥的。 if (IsRecv(n)) { string src_name; s = GetNodeAttr(n->def(), "send_device", &src_name); if (!s.ok()) return s; DeviceNameUtils::ParsedName parsed_src_name; if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) { s = errors::Internal("Bad send_device attr '", src_name, "' in node ", n->name()); return s; } if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) { // Value is going to be the sink of an RPC. 
值将是 RPC 的接收器 attr->set_nic_compatible(true); VLOG(2) << "node " << n->name() << " is the sink of an RPC in"; } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) && parsed_src_name.type != "CPU") { // Value is going to be the sink of a local DMA from GPU to CPU (or other types of accelerators). // 值将是从 GPU 到 CPU(或其他类型的加速器)的本地 DMA 的接收器。 attr->set_gpu_compatible(true); VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy"; } else { VLOG(2) << "default alloc case local type " << local_dev_name.type << " remote type " << parsed_src_name.type; } } if (IsSend(dst)) { string dst_name; s = GetNodeAttr(dst->def(), "recv_device", &dst_name); if (!s.ok()) return s; DeviceNameUtils::ParsedName parsed_dst_name; if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) { s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ", n->name()); return s; } if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) { // Value is going to be the source of an RPC. 值将成为 RPC 的来源。 attr->set_nic_compatible(true); VLOG(2) << "node " << n->name() << " is the source of an RPC out"; } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) && parsed_dst_name.type != "CPU") { // Value is going to be the source of a local DMA from CPU to GPU (or other types of accelerators). // Note that this does not cover the case where the allocation of the // output tensor is not generated by the src: n. // 值将是从 CPU 到 GPU(或其他类型的加速器)的本地 DMA 的来源。 // 请注意,这不包括 src: n 不生成输出张量的分配的情况。 attr->set_gpu_compatible(true); VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy"; } else { VLOG(2) << "default alloc case local type " << local_dev_name.type << " remote type " << parsed_dst_name.type; } } return s; } // The state associated with one invocation of ExecutorImpl::Run. // ExecutorState dispatches nodes when they become ready and keeps // track of how many predecessors of a node have not done (pending_). 
// 状态与执行一个 ExecutorImpl::Run 的调用相关联。 // ExecutorState 在准备就绪时调度节点,并跟踪节点尚未完成的前几个(pending_)。 class ExecutorState { public: ExecutorState(const Executor::Args& args, ExecutorImpl* impl); ~ExecutorState(); void RunAsync(Executor::DoneCallback done); private: // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). // 张量指针(pass-by-reference)或张量(pass-by-value)。 // TODO(yuanbyu): A better way to do "has_value"? struct Entry { Entry() {} Entry(const Entry& other) : ref(other.ref), ref_mu(other.ref_mu), has_value(other.has_value), val_field_is_set(other.val_field_is_set), alloc_attr(other.alloc_attr), device_context(other.device_context) { if (val_field_is_set) { val.Init(*other.val); } } ~Entry() { if (val_field_is_set) val.Destroy(); } Entry& operator=(const Entry& other) { if (val_field_is_set) { val.Destroy(); } ref = other.ref; ref_mu = other.ref_mu; has_value = other.has_value; val_field_is_set = other.val_field_is_set; alloc_attr = other.alloc_attr; device_context = other.device_context; if (val_field_is_set) { val.Init(*other.val); } return *this; } Entry& operator=(Entry&& other) { if (val_field_is_set) { val.Destroy(); } ref = other.ref; ref_mu = other.ref_mu; has_value = other.has_value; val_field_is_set = other.val_field_is_set; alloc_attr = other.alloc_attr; device_context = other.device_context; if (val_field_is_set) { val.Init(std::move(*other.val)); } return *this; } // Clears the <val> field. 清除<val>字段。 void ClearVal() { if (val_field_is_set) { val.Destroy(); val_field_is_set = false; } } // A tensor value, if val_field_is_set. ManualConstructor<Tensor> val; Tensor* ref = nullptr; // A tensor reference. 张量引用 mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr. // Whether the value exists, either in <val> or <ref>. // 在 <val> 或 <ref> 中,该值是否存在。 bool has_value = false; bool val_field_is_set = false; // The attributes of the allocator that creates the tensor. 
// 创建张量的分配器的属性。 AllocatorAttributes alloc_attr; // Every entry carries an optional DeviceContext containing // Device-specific information about how the Tensor was produced. // 每个 entry 都带有可选的 DeviceContext, 其中包含有关如何生成 Tensor 的特定设备的信息。 DeviceContext* device_context = nullptr; }; // Contains a value for [node->id()] for the device context assigned by the // device at the beginning of a step. // 在步骤开始时,包含设备中分配的设备上下文的 [node->id()] 的值。 DeviceContextMap device_context_map_; struct TaggedNode; typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; typedef gtl::InlinedVector<Entry, 4> EntryVector; struct IterationState { explicit IterationState(const PendingCounts* pending_counts, int total_input_tensors) : input_tensors(new Entry[total_input_tensors]), outstanding_ops(0), outstanding_frame_count(0), counts_(*pending_counts) { // Initialize with copy of *pending_counts } // The state of an iteration. 迭代的状态。 // One copy per iteration. For iteration k, i-th node's j-th input is in // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either // a tensor pointer (pass-by-reference) or a tensor (pass-by-value). // 每次迭代一个副本。 对于迭代 k, 第 i 个节点的第 j 个输入在 input_tensors[k][impl_->nodes[i].input_start + j] 中。 // entry 是一个张量指针(pass-by-reference)或一个张量(pass-by-value)。 // // NOTE: No need to protect input_tensors[i] by any locks because it // is resized once. Each element of tensors_ is written once by the // source node of an edge and is cleared by the destination of the same // edge. The latter node is never run concurrently with the former node. // 注意:由于调整了一次大小,因此不需要使用任何锁来保护 input_tensors[i]。 // tensors_ 的每个元素由边缘的源节点写入一次,并被同一边缘的目的地清除。 // 后一个节点从不与前一个节点同时运行。 Entry* input_tensors; // The number of outstanding ops for each iteration. 每次迭代的未完成的操作数。 int outstanding_ops; // The number of outstanding frames for each iteration. 
// 每次迭代的未完成的 frames。 int outstanding_frame_count; int pending(PendingCounts::Handle h) { return counts_.pending(h); } int decrement_pending(PendingCounts::Handle h, int v) { return counts_.decrement_pending(h, v); } // Mark a merge node as live // REQUIRES: Node corresponding to "h" is a merge node void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); } // Mark a node to show that processing has started. void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); } // Mark a node to show that processing has completed. void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); } PendingCounts::NodeState node_state(PendingCounts::Handle h) { return counts_.node_state(h); } int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); } void increment_dead_count(PendingCounts::Handle h) { counts_.increment_dead_count(h); } void adjust_for_activation(PendingCounts::Handle h, bool increment_dead, int* pending_result, int* dead_result) { counts_.adjust_for_activation(h, increment_dead, pending_result, dead_result); } ~IterationState() { delete[] input_tensors; } private: PendingCounts counts_; }; struct FrameState { explicit FrameState(const ExecutorImpl* impl, int parallel_iters) : executor(impl), max_parallel_iterations(parallel_iters), num_outstanding_iterations(1) {} // A new frame is created for each loop. Execution starts at iteration 0. // When a value at iteration 0 passes through a NextIteration node, iteration 1 is created // and starts running. // Note that iteration 0 may still be running so multiple iterations may run in parallel. // The frame maintains the state of iterations in several data structures such as pending_count // and input_tensors. When iteration 0 completes, we garbage collect the state of iteration 0. // // A frame instance is considered "done" and can be garbage collected if all its inputs have // entered and all its iterations are "done". 
// // A frame manages the live iterations of an iterative computation. // Iteration i is considered "done" when there are no outstanding ops, frames at iteration i // are done, all recvs for this iteration are completed, and iteration i-1 is done. // For iteration 0, we instead wait for there to be no more pending inputs of the frame. // // Frames and iterations are garbage collected once they are done. // The state we need to keep around is highly dependent on the parallelism enabled by the scheduler. // We may want to have the scheduler dynamically control the outstanding number of live parallel // frames and iterations. To reduce the state space, the scheduler might want to schedule ops // in inner frames first and lower iterations first. // // This frame state is mostly initialized lazily on demand so we don't introduce unnecessary overhead. // The executor the frame is in. const ExecutorImpl* executor = nullptr; // The name of this frame, which is the concatenation of its parent frame name, the iteration // of the parent frame when this frame was created, and the value of the attr 'frame_name'. string frame_name; // The unique id for this frame. Generated by fingerprinting frame_name. uint64 frame_id; // The iteration id of its parent frame when this frame is created. // -1 if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState. int64 parent_iter = -1; // The FrameState of its parent frame. FrameState* parent_frame = nullptr; // The maximum allowed number of parallel iterations. const int max_parallel_iterations; // The number of inputs this frame is still waiting. int num_pending_inputs = 0; // The highest iteration number we have reached so far in this frame. int64 iteration_count GUARDED_BY(mu) = 0; // The number of outstanding iterations. int num_outstanding_iterations GUARDED_BY(mu) = 1; // The active iteration states of this frame. 
gtl::InlinedVector<IterationState*, 12> iterations; // The NextIteration nodes to enter a new iteration. If the number of outstanding iterations // reaches the limit, we will defer the start of the next iteration until the number of // outstanding iterations falls below the limit. std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu); // The values of the loop invariants for this loop. They are added into this list as they // "enter" the frame. When a loop invariant enters, we make it available to all active // iterations. When the frame starts a new iteration, we make all the current loop invariants // available to the new iteration. std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu); // The list of dead exit nodes for the current highest iteration. We // will only "execute" the dead exits of the final iteration. std::vector<const Node*> dead_exits GUARDED_BY(mu); // Static information specific to this frame. PendingCounts* pending_counts = nullptr; int total_input_tensors = 0; std::vector<const Node*>* nodes = nullptr; // Lock ordering: ExecutorState.mu_ < mu. mutex mu; void InitializeFrameInfo(const string& enter_name) { auto it_frame_info = executor->frame_info_.find(enter_name); DCHECK(it_frame_info != executor->frame_info_.end()); ExecutorImpl::FrameInfo* finfo = it_frame_info->second; pending_counts = finfo->pending_counts; total_input_tensors = finfo->total_inputs; num_pending_inputs = finfo->input_count; nodes = finfo->nodes; } inline IterationState* GetIteration(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu) { int index = iter % iterations.size(); return iterations[index]; } inline void SetIteration(int64 iter, IterationState* state) EXCLUSIVE_LOCKS_REQUIRED(mu) { int index = iter % iterations.size(); DCHECK(state == nullptr || iterations[index] == nullptr); iterations[index] = state; } // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. 
// 减少未完成的操作数,并清理框架中的迭代。 如果框架执行完成,返回 true。 inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) { mutex_lock l(mu); return DecrementOutstandingOpsLocked(gview, iter, ready); } // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. // 减少未完成的操作数,并清理框架中的迭代。 如果框架执行完成,返回 true。 inline bool DecrementOutstandingOpsLocked(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu) { IterationState* istate = GetIteration(iter); istate->outstanding_ops--; if (istate->outstanding_ops != 0) { return false; } else { return CleanupIterations(gview, iter, ready); } } // Returns true if the computation in the frame is completed. // 如果帧中的计算完成,则返回true。 inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) { return (num_pending_inputs == 0 && num_outstanding_iterations == 0); } // Returns true if the iteration of the frame is completed. bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu); // Increments the iteration id. If this is a new iteration, initialize it. // 增加迭代 ID。 如果这是一个新的迭代,初始化它。 void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); // Activate all the deferred NextIteration nodes in a new iteration. // 在新的迭代中激活所有延迟的NextIteration节点。 void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); // Activate all the current loop invariants in a new iteration. void ActivateLoopInvs(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); // Add a new loop invariant and make it available to all active iterations. void AddLoopInv(const NodeItem* item, const Entry& value, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); // Activate the successors of a node. Contents of *outputs are left in an // indeterminate state after returning from this method. 
// 激活一个节点的后继。 *outputs 的内容在从该方法返回后保持不确定状态。 void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, EntryVector* outputs, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); // Cleanup iterations of this frame starting from iteration iter. bool CleanupIterations(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu); ~FrameState() { for (size_t i = 0; i < iterations.size(); ++i) { delete iterations[i]; iterations[i] = nullptr; } } }; // A tagged node: <frame*, iter, node*>. struct TaggedNode { const Node* node = nullptr; FrameState* input_frame = nullptr; int64 input_iter = -1; bool is_dead = false; TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter, bool dead) { node = t_node; input_frame = in_frame; input_iter = in_iter; is_dead = dead; } }; // A drop-in replacement for std::deque<TaggedNode>. We typically don't // have that many nodes in the ready queue, so we just use a vector and // don't free up memory from the queue as we consume nodes. class TaggedNodeReadyQueue { public: TaggedNodeReadyQueue() : front_index_(0) {} void push_back(TaggedNode node) { ready_.push_back(node); } TaggedNode front() const { DCHECK_LT(front_index_, ready_.size()); return ready_[front_index_]; } void pop_front() { DCHECK_LT(front_index_, ready_.size()); front_index_++; if ((front_index_ == ready_.size()) || (front_index_ > 16384)) { if (front_index_ == ready_.size()) { ready_.clear(); } else { // Lots of unused entries at beginning of vector: move everything down // to start of vector. ready_.erase(ready_.begin(), ready_.begin() + front_index_); } front_index_ = 0; } } bool empty() const { return ready_.empty(); } const TaggedNode* begin() const { return ready_.begin() + front_index_; } const TaggedNode* end() const { return ready_.end(); } private: gtl::InlinedVector<TaggedNode, 16> ready_; int front_index_; }; struct AsyncState; const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply. 
// true if LogMemory::IsEnabled(). Used to check memory enabled cheaply. const bool log_memory_; int64 step_id_; // Not owned. Rendezvous* rendezvous_; SessionState* session_state_; TensorStore* tensor_store_; // Step-local container. ScopedStepContainer* step_container_; StepStatsCollector* stats_collector_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; FunctionCallFrame* call_frame_; const ExecutorImpl* impl_; CancellationManager* cancellation_manager_; Executor::Args::Runner runner_; bool sync_on_finish_; // Owned. // A flag that is set on error after the frame state has been // dumped for diagnostic purposes. bool dumped_on_error_ = false; // The root frame in which the execution of this step is started. FrameState* root_frame_; // Invoked when the execution finishes. Executor::DoneCallback done_cb_; std::atomic_int_fast32_t num_outstanding_ops_; mutex mu_; Status status_ GUARDED_BY(mu_); // Mapping from frame name to outstanding frames. A new frame is created // at some iteration of an active frame. So the unique key for the new // child frame is composed of the name of the parent frame, the iteration // number at which the parent frame is creating the new frame, and the // name of the new frame from nodedef. gtl::FlatMap<string, FrameState*, HashStr> outstanding_frames_ GUARDED_BY(mu_); // The unique name of a frame. inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) { return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); } // Find an existing or create a new child frame in the frame 'frame' at iteration 'iter'. void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child); // Delete a frame. Called when the frame is done. void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready); // Cleanup frames and iterations starting from frame/iter. 
Called when a child frame is done. void CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready); // Process a ready node in current thread. // 在当前线程中处理可用节点。 void Process(TaggedNode node, int64 scheduled_usec); // Before invoking item->kernel, fills in its "inputs". // 在调用 item->kernel 之前,填入它的"inputs"。 Status PrepareInputs(const NodeItem& item, Entry* first_input, TensorValueVec* inputs, DeviceContextVec* input_device_contexts, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead); // After item->kernel computation is done, processes its outputs. // 在 item->kernel 计算完成后,处理其输出。 Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, NodeExecStats* stats); // After processing the outputs, propagates the outputs to their dsts. // Contents of *outputs are left in an indeterminate state after returning from this method. // 在处理输出后,传播输出数据到 dsts。从该方法返回后,*outputs 的内容处于不确定状态。 void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item, EntryVector* outputs, TaggedNodeSeq* ready); // "node" just finishes. Takes ownership of "stats". Returns true if execution has completed. // "node" 完成。取得"stats"的所有权。 如果执行完成,则返回 true。 bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. // 安排'ready'中所有 expensive nodes ,并将'ready'中的所有 inexpensive nodes 放入'inline_ready'中。 void ScheduleReady(const TaggedNodeSeq& ready, TaggedNodeReadyQueue* inline_ready); // For debugging/logging only. inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id); // Provide debugging output about an outstanding node in the executor. 
// 提供执行器中未完成节点的调试输出。 void DumpPendingNodeState(const int node_id, const Entry* input_vector, bool show_nodes_with_no_ready_inputs); void DumpActiveNodeState(const int node_id, const Entry* input_vector); // Provide debugging output about an outstanding iteration in the executor. // 提供执行器中未完成的迭代的调试输出。 void DumpIterationState(const FrameState* frame, IterationState* iteration); // Provide debugging output of the state of the executor. // 提供执行器状态的调试输出。 void DumpState(); const Tensor* GetTensorValueForDump(const Entry& input); // Clean up when this executor is done. // 当这个执行者完成时清理。 void Finish(); // A standalone routine for this expression so that we can express // that we don't want thread safety analysis on this reference (it's // safe to do without the lock because the iterations array never // resizes and this particular iteration's array element will not // be changed out from under us because the iteration is still alive). Entry* GetInputTensors(FrameState* input_frame, int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS { return input_frame->GetIteration(input_iter)->input_tensors; } }; ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) : vlog_(VLOG_IS_ON(1)), log_memory_(LogMemory::IsEnabled()), step_id_(args.step_id), rendezvous_(args.rendezvous), session_state_(args.session_state), tensor_store_(args.tensor_store), step_container_(args.step_container), stats_collector_(args.stats_collector), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), cancellation_manager_(args.cancellation_manager), runner_(args.runner), sync_on_finish_(args.sync_on_finish), num_outstanding_ops_(0) { // We start the entire execution in iteration 0 of the root frame // so let us create the root frame and the state for iteration 0. // We assume root_frame_->frame_name.empty(). 
// 我们从根 frame 的第 0 次迭代的开始整个的执行过程, // 所以我们创建根 frame 和迭代状态 0。我们假设 root_frame_->frame_name.empty()。 root_frame_ = new FrameState(impl_, 1); root_frame_->frame_id = 0; // must be 0 root_frame_->InitializeFrameInfo(root_frame_->frame_name); // Initialize iteration 0. 初始化迭代 0 root_frame_->iterations.resize(root_frame_->max_parallel_iterations); root_frame_->iterations[0] = new IterationState( root_frame_->pending_counts, root_frame_->total_input_tensors); outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); } ExecutorState::~ExecutorState() { for (auto name_frame : outstanding_frames_) { delete name_frame.second; } for (auto it : device_context_map_) { it->Unref(); } delete slice_reader_cache_; } Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, ControlFlowInfo* cf_info) { const int num_nodes = g->num_node_ids(); cf_info->frame_names.resize(num_nodes); std::vector<Node*> parent_nodes; parent_nodes.resize(num_nodes); std::vector<bool> visited; visited.resize(num_nodes); string frame_name; std::deque<Node*> ready; // Initialize with the root nodes. 初始化根节点 for (Node* n : g->nodes()) { if (n->in_edges().empty()) { visited[n->id()] = true; cf_info->unique_frame_names.insert(frame_name); ready.push_back(n); } } while (!ready.empty()) { Node* curr_node = ready.front(); int curr_id = curr_node->id(); ready.pop_front(); Node* parent = nullptr; if (IsEnter(curr_node)) { // Enter a child frame. TF_RETURN_IF_ERROR(GetNodeAttr(curr_node->def(), "frame_name", &frame_name)); parent = curr_node; } else if (IsExit(curr_node)) { // Exit to the parent frame. parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[parent->id()]; parent = parent_nodes[parent->id()]; } else { parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[curr_id]; } for (const Edge* out_edge : curr_node->out_edges()) { Node* out = out_edge->dst(); int out_id = out->id(); // Add to ready queue if not visited. 
bool is_visited = visited[out_id]; if (!is_visited) { ready.push_back(out); visited[out_id] = true; // Process the node 'out'. cf_info->frame_names[out_id] = frame_name; parent_nodes[out_id] = parent; cf_info->unique_frame_names.insert(frame_name); } } } return Status::OK(); } void ExecutorImpl::InitializePending(const Graph* graph, const ControlFlowInfo& cf_info) { for (auto& it : cf_info.unique_frame_names) { FrameInfo* finfo = EnsureFrameInfo(it); PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout); DCHECK_EQ(finfo->pending_counts, nullptr); finfo->pending_counts = counts; } for (const Node* n : graph->nodes()) { const int id = n->id(); const string& name = cf_info.frame_names[id]; int max_pending, max_dead; GetMaxPendingCounts(n, &max_pending, &max_dead); const NodeItem* item = gview_.node(id); PendingCounts* counts = EnsureFrameInfo(name)->pending_counts; counts->set_initial_count(item->pending_id, max_pending); } } // 异步运行 void ExecutorState::RunAsync(Executor::DoneCallback done) { const Graph* graph = impl_->graph_; TaggedNodeSeq ready; // Ask the device to fill in the device context map. // 请求设备填写设备上下文映射。 Device* device = impl_->params_.device; Status fill_status = device->FillContextMap(graph, &device_context_map_); if (!fill_status.ok()) { done(fill_status); return; } // Initialize the ready queue. // 初始化就绪队列。 for (const Node* n : impl_->root_nodes_) { DCHECK_EQ(n->in_edges().size(), 0); ready.push_back(TaggedNode{n, root_frame_, 0, false}); } if (ready.empty()) { done(Status::OK()); } else { num_outstanding_ops_ = ready.size(); root_frame_->iterations[0]->outstanding_ops = ready.size(); done_cb_ = done; // Schedule to run all the ready ops in thread pool. // 在线程池中运行所有准备好的操作的计划。 ScheduleReady(ready, nullptr); } } // State kept alive for executing an asynchronous node in another thread. 
// NOTE: We need to make a copy of p.input, p.input_device_contexts, and // p.input_alloc_attrs for asynchronous kernels because OpKernelContext // methods like input_type(i) needs the param points to valid input type vector. // It's not an issue for sync kernels because these vectors are kept on the stack. // 状态保持活动,在另一个线程中执行异步节点。 // 注意:我们需要为异步内核创建一个p.input, p.input_device_contexts 和 p.input_alloc_attrs 的副本, // 因为像 input_type(i) 这样的 OpKernelContext 方法需要 param 指向有效的输入类型向量。 // 同步内核不是问题,因为这些向量保存在堆栈上。 struct ExecutorState::AsyncState { AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, const NodeItem* _item, Entry* _first_input, NodeExecStats* _stats) : saved_inputs(*p.inputs), saved_input_device_contexts(*p.input_device_contexts), saved_input_alloc_attrs(*p.input_alloc_attrs), params(p), tagged_node(_tagged_node), item(_item), first_input(_first_input), // ParamsButClearingEigenGPUDevice does equivalent of params.eigen_gpu_device = nullptr; // ParamsButClearingEigenGPUDevice 相当于 params.eigen_gpu_device = nullptr; ctx(ParamsButClearingEigenGPUDevice(¶ms), item->num_outputs), stats(_stats) { params.inputs = &saved_inputs; params.input_device_contexts = &saved_input_device_contexts; params.input_alloc_attrs = &saved_input_alloc_attrs; } TensorValueVec saved_inputs; DeviceContextVec saved_input_device_contexts; AllocatorAttributeVec saved_input_alloc_attrs; OpKernelContext::Params params; TaggedNode tagged_node; const NodeItem* item; Entry* first_input; OpKernelContext ctx; NodeExecStats* stats; private: OpKernelContext::Params* ParamsButClearingEigenGPUDevice( OpKernelContext::Params* p) { // Ensure OpKernelContext constructor will make a new eigen GPU device if necessary. 
// 如果需要,请确保 OpKernelContext 构造函数将创建一个新的 eigen GPU 设备。 p->eigen_gpu_device = nullptr; // Force allocation return p; } }; void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { const GraphView& gview = impl_->gview_; TaggedNodeSeq ready; TaggedNodeReadyQueue inline_ready; // Parameters passed to OpKernel::Compute. 参数传递给 OpKernel::Compute TensorValueVec inputs; DeviceContextVec input_device_contexts; AllocatorAttributeVec input_alloc_attrs; OpKernelContext::Params params; params.step_id = step_id_; Device* device = impl_->params_.device; params.device = device; params.log_memory = log_memory_; params.record_tensor_accesses = impl_->device_record_tensor_accesses_; params.rendezvous = rendezvous_; params.session_state = session_state_; params.tensor_store = tensor_store_; params.cancellation_manager = cancellation_manager_; params.call_frame = call_frame_; params.function_library = impl_->params_.function_library; params.resource_manager = device->resource_manager(); params.step_container = step_container_; params.slice_reader_cache = slice_reader_cache_; params.inputs = &inputs; params.input_device_contexts = &input_device_contexts; params.input_alloc_attrs = &input_alloc_attrs; params.runner = &runner_; Status s; NodeExecStats* stats = nullptr; EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); while (!inline_ready.empty()) { tagged_node = inline_ready.front(); inline_ready.pop_front(); const Node* node = tagged_node.node; FrameState* input_frame = tagged_node.input_frame; int64 input_iter = tagged_node.input_iter; const int id = node->id(); const NodeItem& item = *gview.node(id); // TODO(misard) Replace with a finer-grain enabling flag once we // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { mutex_lock l(input_frame->mu); input_frame->GetIteration(input_iter)->mark_started(item.pending_id); } // Set the device_context for this node id, if it exists. 
if (id < device_context_map_.size()) { params.op_device_context = device_context_map_[id]; } params.track_allocations = false; stats = nullptr; if (stats_collector_ && !tagged_node.is_dead) { // track allocations if and only if we are collecting statistics // 跟踪分配,当且仅当我们收集统计数据 params.track_allocations = true; stats = new NodeExecStats; stats->set_node_name(node->name()); nodestats::SetScheduled(stats, scheduled_usec); nodestats::SetAllStart(stats); } if (vlog_) { VLOG(1) << "Process node: " << id << " step " << params.step_id << " " << SummarizeNodeDef(node->def()) << " is dead: " << tagged_node.is_dead; } Entry* input_tensors = GetInputTensors(input_frame, input_iter); Entry* first_input = input_tensors + item.input_start; outputs.clear(); TensorReferenceVector accessed_tensors; DeviceContext* device_context = nullptr; // Only execute this node if it is not dead or it is a send/recv transfer node. // For transfer nodes, we need to propagate the "dead" bit even when the node is dead. // 只有在该节点没有 dead 或者该节点是 send/recv 传输节点的情况下才执行。 // 对于传输节点,即使节点死亡,我们也需要传播 "dead" 位。 bool launched_asynchronously = false; if (tagged_node.is_dead && !IsTransferNode(node)) { outputs.resize(item.num_outputs); } else { // Prepares inputs. 准备输入 bool is_input_dead = false; s = PrepareInputs(item, first_input, &inputs, &input_device_contexts, &input_alloc_attrs, &is_input_dead); if (!s.ok()) { // Clear inputs. 清除输入 int num_inputs = item.num_inputs; for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } MaybeMarkCompleted(input_frame, input_iter, id); // Continue to process the nodes in 'inline_ready'. // 继续处理 'inline_ready' 中的节点。 completed = NodeDone(s, item.node, ready, stats, &inline_ready); continue; } // Set up compute params. 
设置计算参数。 OpKernel* op_kernel = item.kernel; params.op_kernel = op_kernel; params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); params.is_input_dead = is_input_dead; params.output_attr_array = item.output_attrs(); // 异步计算 if (item.kernel_is_async) { // Asynchronous computes. AsyncOpKernel* async = item.kernel->AsAsync(); DCHECK(async != nullptr); launched_asynchronously = true; AsyncState* state = new AsyncState(params, tagged_node, &item, first_input, stats); auto done = [this, state]() { Device* device = impl_->params_.device; NodeExecStats* stats = state->stats; // Shorthand Entry* first_input = state->first_input; // Shorthand if (vlog_) { VLOG(2) << this << " Async kernel done: " << SummarizeNodeDef(state->item->node->def()); } if (stats) nodestats::SetOpEnd(stats); EntryVector outputs; // After item->kernel computation is done, processes its outputs. // 在 item->kernel 计算完成后,处理其输出。 Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats); if (stats) nodestats::SetMemory(stats, &state->ctx); // Clears inputs. 清除输入 const int num_inputs = state->item->num_inputs; for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } FrameState* input_frame = state->tagged_node.input_frame; const int64 input_iter = state->tagged_node.input_iter; const int id = state->tagged_node.node->id(); // For debugging/logging only. // 仅用于调试/记录。 MaybeMarkCompleted(input_frame, input_iter, id); TaggedNodeSeq ready; if (s.ok()) { // void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item, // EntryVector* outputs, TaggedNodeSeq* ready); // After processing the outputs, propagates the outputs to their dsts. // Contents of *outputs are left in an indeterminate state after // returning from this method. 
// 处理输出后,将输出数据传播到它们的 dsts。 *outputs 的内容在从该方法返回后保持不确定状态。 PropagateOutputs(state->tagged_node, state->item, &outputs, &ready); } outputs.clear(); if (s.ok() && impl_->device_record_tensor_accesses_) { // Get the list of all tensors accessed during the execution // 获取在执行期间访问的所有张量的列表 TensorReferenceVector accessed; state->ctx.retrieve_accessed_tensors(&accessed); if (stats) nodestats::SetReferencedTensors(stats, accessed); // callee takes ownership of the vector // 被调用者拥有向量的所有权 device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(), accessed); } bool completed = NodeDone(s, state->item->node, ready, stats, nullptr); delete state; if (completed) Finish(); }; if (stats) nodestats::SetOpStart(stats); // 执行异步计算 device->ComputeAsync(async, &state->ctx, done); } // 同步计算 else { // Synchronous computes. OpKernelContext ctx(¶ms, item.num_outputs); if (stats) nodestats::SetOpStart(stats); // 执行同步计算 device->Compute(CHECK_NOTNULL(op_kernel), &ctx); if (stats) nodestats::SetOpEnd(stats); s = ProcessOutputs(item, &ctx, &outputs, stats); if (s.ok() && impl_->device_record_tensor_accesses_) { // Get the list of all tensors accessed during the execution // 获取在执行期间访问的所有张量的列表 ctx.retrieve_accessed_tensors(&accessed_tensors); device_context = ctx.op_device_context(); } if (stats) nodestats::SetMemory(stats, &ctx); } } if (!launched_asynchronously) { // Clears inputs. const int num_inputs = item.num_inputs; for (int i = 0; i < num_inputs; ++i) { (first_input + i)->ClearVal(); } MaybeMarkCompleted(input_frame, input_iter, id); // Propagates outputs. // 传播 outputs if (s.ok()) { PropagateOutputs(tagged_node, &item, &outputs, &ready); } outputs.clear(); if (!accessed_tensors.empty()) { if (stats) nodestats::SetReferencedTensors(stats, accessed_tensors); // device_context is set above in synchronous computes // device_context 设置在同步计算中 device->ConsumeListOfAccessedTensors(device_context, accessed_tensors); } if (stats) { scheduled_usec = nodestats::NowInUsec(); } // Postprocess. 
后期处理 completed = NodeDone(s, item.node, ready, stats, &inline_ready); } } // while !inline_ready.empty() // This thread of computation is done if completed = true. // 如果 completed = true,则此线程的计算完成。 if (completed) Finish(); } Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, TensorValueVec* inputs, DeviceContextVec* input_device_contexts, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { const Node* node = item.node; inputs->clear(); inputs->resize(item.num_inputs); input_device_contexts->clear(); input_device_contexts->resize(item.num_inputs); input_alloc_attrs->clear(); input_alloc_attrs->resize(item.num_inputs); *is_input_dead = false; bool is_merge = item.is_merge; for (int i = 0; i < item.num_inputs; ++i) { const bool expect_ref = IsRefType(item.input_type(i)); Entry* entry = first_input + i; (*input_device_contexts)[i] = entry->device_context; (*input_alloc_attrs)[i] = entry->alloc_attr; // i-th input. TensorValue* inp = &(*inputs)[i]; // Only merge and transfer nodes can have no-value inputs. if (!entry->has_value) { if (!is_merge) { DCHECK(IsTransferNode(node)); DCHECK(!entry->val_field_is_set); entry->has_value = true; entry->val_field_is_set = true; entry->val.Init(*kEmptyTensor); inp->tensor = entry->val.get(); *is_input_dead = true; } continue; } if (entry->ref == nullptr) { if (expect_ref) { return AttachDef( errors::InvalidArgument(i, "-th input expects a ref type"), item.kernel->def()); } inp->tensor = entry->val.get(); } else { if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { return AttachDef( errors::FailedPrecondition("Attempting to use uninitialized value ", item.kernel->def().input(i)), item.kernel->def()); } if (expect_ref) { inp->mutex_if_ref = entry->ref_mu; inp->tensor = entry->ref; } else { // Automatically deref the tensor ref when the op expects a tensor but is given a // ref to a tensor. Need to deref it under the mutex. 
{ mutex_lock l(*(entry->ref_mu)); DCHECK(!entry->val_field_is_set); entry->val.Init(*entry->ref); entry->val_field_is_set = true; } entry->ref = nullptr; entry->ref_mu = nullptr; inp->tensor = entry->val.get(); } } } return Status::OK(); } Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, NodeExecStats* stats) { const Node* node = item.node; DCHECK_EQ(0, outputs->size()); outputs->resize(item.num_outputs); Status s = ctx->status(); if (!s.ok()) { s = AttachDef(s, item.kernel->def()); // TODO(misard) Replace with a finer-grain enabling flag once we // add better optional debugging support. // 一旦我们添加更好的可选调试支持,更换为更细粒度的启用标志。 if (vlog_ && VLOG_IS_ON(1)) { LOG(WARNING) << this << " Compute status: " << s; DumpState(); } return s; } // Get the device_context for this node id, if it exists. // 获取此节点 ID 的 device_context (如果存在)。 DeviceContext* device_context = nullptr; if (node->id() < device_context_map_.size()) { device_context = device_context_map_[node->id()]; } // Experimental: debugger (tfdb) access to intermediate node completion. // 实验:调试器(tfdb)访问中间节点完成。 if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) { // If the node has no output, invoke the callback with output slot set to // -1, signifying that this is a no-output node. // 如果节点没有输出,调用回调,这个回调的输出设置为-1,表示这是一个无输出节点 s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr, false, ctx)); } for (int i = 0; i < item.num_outputs; ++i) { TensorValue val = ctx->release_output(i); if (*ctx->is_output_dead() || val.tensor == nullptr) { // Unless it's a Switch or a Recv, the node must produce a // tensor value at i-th output. // 除非是 Switch 或 Recv 节点,否则节点必须在第 i 个输出端产生张量值。 if (!IsSwitch(node) && !IsRecv(node)) { s.Update(errors::Internal("Missing ", i, "-th output from ", SummarizeNodeDef(node->def()))); } } else { Entry* out = &((*outputs)[i]); // Set the device context of the output entry. 
// 设置输出条目的设备上下文。 out->device_context = device_context; // Set the allocator attributes of the output entry. // 设置输出条目的 allocator 属性。 out->alloc_attr = ctx->output_alloc_attr(i); // Sanity check of output tensor types. // 输出张量类型的健全检查。 DataType dtype = val->dtype(); if (val.is_ref()) dtype = MakeRefType(dtype); if (dtype == item.output_type(i)) { if (stats && val.tensor->IsInitialized()) { nodestats::SetOutput(stats, i, val.tensor); } if (val.is_ref()) { out->has_value = true; out->ref = val.tensor; out->ref_mu = val.mutex_if_ref; if (log_memory_) { Tensor to_log; { // Dereference the tensor under the lock. // 在线程锁下解引用这个 tensor mutex_lock l(*out->ref_mu); to_log = *out->ref; } LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, to_log); } // Experimental: debugger (tfdb) access to intermediate node outputs. if (impl_->params_.node_outputs_cb != nullptr) { s.Update(impl_->params_.node_outputs_cb(item.node->name(), i, out->ref, true, ctx)); } } else { // NOTE that std::move is used here, so val.tensor goes to // uninitialized state (val.tensor->IsInitialized return false). // 注意,这里使用 std::move, 所以 val.tensor 转到未初始化的状态( val.tensor->IsInitialized 返回 false)。 DCHECK(!out->val_field_is_set); out->has_value = true; out->val_field_is_set = true; out->val.Init(std::move(*val.tensor)); if (log_memory_) { LogMemory::RecordTensorOutput(ctx->op_kernel().name(), ctx->step_id(), i, *out->val); } // Experimental: debugger access to intermediate node outputs. if (impl_->params_.node_outputs_cb != nullptr) { s.Update(impl_->params_.node_outputs_cb( item.node->name(), i, out->val.get(), false, ctx)); } } } else { s.Update(errors::Internal("Output ", i, " of type ", DataTypeString(dtype), " does not match declared output type ", DataTypeString(item.output_type(i)), " for node ", SummarizeNodeDef(node->def()))); } } if (!val.is_ref()) { // If OpKernelContext returns outputs via pass-by-value, we don't need this trouble. 
// 如果 OpKernelContext 通过 pass-by-value 返回输出,我们不需要这个 val.tensor。 delete val.tensor; } } return s; } void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item, EntryVector* outputs, TaggedNodeSeq* ready) { const Node* node = tagged_node.node; FrameState* input_frame = tagged_node.input_frame; int64 input_iter = tagged_node.input_iter; const bool is_dead = tagged_node.is_dead; // Propagates outputs along out edges, and puts newly ready nodes into the ready queue. // 沿着输出边传播输出,并将新准备好的节点放入就绪队列。 ready->clear(); bool is_frame_done = false; FrameState* output_frame = input_frame; int64 output_iter = input_iter; if (!item->is_enter_exit_or_next_iter) { // Fast path for nodes types that don't need special handling // 不需要特殊处理的节点类型的快速路径 DCHECK_EQ(input_frame, output_frame); // Normal path for most nodes 大多数节点的正常路径 mutex_lock l(input_frame->mu); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); is_frame_done = input_frame->DecrementOutstandingOpsLocked( &impl_->gview_, input_iter, ready); } else if (item->is_enter) { bool is_constant; Status s = GetNodeAttr(node->def(), "is_constant", &is_constant); DCHECK(s.ok()) << s; FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame); output_iter = 0; { const NodeItem* item = impl_->gview_.node(node->id()); mutex_lock l(output_frame->mu); if (is_constant) { // Propagate to all active iterations if this is a loop invariant. // 如果这是循环不变,则传播到所有活动迭代 output_frame->AddLoopInv(item, (*outputs)[0], ready); } else { output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } output_frame->num_pending_inputs--; } is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); } else if (item->is_exit) { if (is_dead) { mutex_lock l(input_frame->mu); // Stop and remember this node if it is a dead exit. 
// 如果它是一个死亡的出口,停止并记住这个节点。 if (input_iter == input_frame->iteration_count) { input_frame->dead_exits.push_back(node); } is_frame_done = input_frame->DecrementOutstandingOpsLocked( &impl_->gview_, input_iter, ready); } else { output_frame = input_frame->parent_frame; output_iter = input_frame->parent_iter; { mutex_lock l(output_frame->mu); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); } } else { DCHECK(IsNextIteration(node)); mutex_lock l(input_frame->mu); if (is_dead) { // Stop the deadness propagation. // 停止死亡传播。 output_frame = nullptr; } else { if (input_iter == input_frame->iteration_count && input_frame->num_outstanding_iterations == input_frame->max_parallel_iterations) { // Reached the maximum for parallel iterations. // 达到并行迭代的最大值。 input_frame->next_iter_roots.push_back({node, (*outputs)[0]}); output_frame = nullptr; } else { // If this is a new iteration, start it. // 如果这是一个新的迭代,启动它。 if (input_iter == input_frame->iteration_count) { input_frame->IncrementIteration(&impl_->gview_, ready); } output_iter = input_iter + 1; } } if (output_frame != nullptr) { // This is the case when node is not Enter, Exit, or NextIteration. // 当节点不是 Enter, Exit 或 NextIteration 时是这种情况。 DCHECK(input_frame == output_frame); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } is_frame_done = input_frame->DecrementOutstandingOpsLocked( &impl_->gview_, input_iter, ready); } // At this point, this node is completely done. We also know if the // completion of this node makes its frame completed. // 此时,这个节点已经完成了。 我们还知道这个节点的完成是否使其 frame 完成。 if (is_frame_done) { FrameState* parent_frame = input_frame->parent_frame; int64 parent_iter = input_frame->parent_iter; DeleteFrame(input_frame, ready); if (parent_frame != nullptr) { // The completion of frame may cause completions in its parent frame. // So clean things up recursively. 
// 帧的完成可能导致其父帧中的完成。所以清理的东西递归。 CleanupFramesIterations(parent_frame, parent_iter, ready); } } } bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready) { if (stats) { nodestats::SetAllEnd(stats); if (!SetTimelineLabel(node, stats)) { // Only record non-transfer nodes. 只记录非传输节点。 stats_collector_->Save(impl_->params_.device->name(), stats); } else { delete stats; } } bool abort_run = false; if (!s.ok()) { // Some error happened. This thread of computation is done. // 发生了一些错误。 这个计算线程完成了。 mutex_lock l(mu_); if (status_.ok()) { abort_run = true; status_ = s; } } if (abort_run) { TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); if (rendezvous_) { rendezvous_->StartAbort(s); } if (cancellation_manager_) { cancellation_manager_->StartCancel(); } } bool completed = false; int ready_size = ready.size(); if (ready_size == 0 || !s.ok()) { completed = (num_outstanding_ops_.fetch_sub(1) == 1); } else if (ready_size > 1) { num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); } // Schedule the ready nodes in 'ready'. // 在'ready'中安排就绪节点。 if (s.ok()) { ScheduleReady(ready, inline_ready); } return completed; } void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready, TaggedNodeReadyQueue* inline_ready) { if (ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } if (inline_ready == nullptr) { // Schedule to run all the ready ops in thread pool. // 计划运行在线程池中所有准备好的操作。 for (auto& tagged_node : ready) { runner_([=]() { Process(tagged_node, scheduled_usec); }); } return; } const GraphView& gview = impl_->gview_; const TaggedNode* curr_expensive_node = nullptr; for (auto& tagged_node : ready) { const NodeItem& item = *gview.node(tagged_node.node->id()); if (tagged_node.is_dead || !item.kernel_is_expensive) { // Inline this inexpensive node. 
inline_ready->push_back(tagged_node); } else { if (curr_expensive_node) { // Dispatch to another thread since there is plenty of work to do for this thread. // 调度到另一个线程,因为这个线程有很多工作要做。 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, scheduled_usec)); } curr_expensive_node = &tagged_node; } } if (curr_expensive_node) { if (inline_ready->empty()) { // Tail recursion optimization // 尾部递归优化 inline_ready->push_back(*curr_expensive_node); } else { // There are inline nodes to run already. We dispatch this expensive node to other thread. // 有 inline nodes 已经可以运行。 我们将这个 expensive node 发送到其他线程。 runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, scheduled_usec)); } } } inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter, int64 node_id) { // TODO(misard) Replace with a finer-grain enabling flag once we // add better optional debugging support. if (vlog_ && VLOG_IS_ON(1)) { const NodeItem* item = impl_->gview_.node(node_id); mutex_lock l(frame->mu); frame->GetIteration(iter)->mark_completed(item->pending_id); } } const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) { if (!input.has_value) { return kEmptyTensor; } else if (input.ref == nullptr) { return input.val.get(); } else { return input.ref; } } void ExecutorState::DumpPendingNodeState( const int node_id, const Entry* input_vector, const bool show_nodes_with_no_ready_inputs) { const NodeItem& node_item = *impl_->gview_.node(node_id); const Node& node = *node_item.node; const int input_base = node_item.input_start; if (!show_nodes_with_no_ready_inputs) { bool has_ready_input = false; for (int i = 0; i < node.num_inputs(); ++i) { const Entry& input = input_vector[input_base + i]; const Tensor* tensor = GetTensorValueForDump(input); if (tensor->IsInitialized()) { has_ready_input = true; break; } } if (!has_ready_input) { return; } } LOG(WARNING) << " Pending Node: " << node.DebugString(); for (int i = 0; i < node.num_inputs(); ++i) { const 
Entry& input = input_vector[input_base + i]; const Tensor* tensor = GetTensorValueForDump(input); if (tensor->IsInitialized()) { LOG(WARNING) << " Input " << i << ": " << strings::StrCat( "Tensor<type: ", DataTypeString(tensor->dtype()), " shape: ", tensor->shape().DebugString(), ">"); } else { LOG(WARNING) << " Input " << i << ": not present"; } } } void ExecutorState::DumpActiveNodeState(const int node_id, const Entry* input_vector) { const NodeItem& node_item = *impl_->gview_.node(node_id); const Node& node = *node_item.node; LOG(WARNING) << " Active Node: " << node.DebugString(); const int input_base = node_item.input_start; for (int i = 0; i < node.num_inputs(); ++i) { const Entry& input = input_vector[input_base + i]; const Tensor* tensor = GetTensorValueForDump(input); if (tensor->IsInitialized()) { LOG(WARNING) << " Input " << i << ": " << strings::StrCat( "Tensor<type: ", DataTypeString(tensor->dtype()), " shape: ", tensor->shape().DebugString(), ">"); } else { LOG(WARNING) << " Input " << i << ": not present"; } } } void ExecutorState::DumpIterationState(const FrameState* frame, IterationState* iteration) { const std::vector<const Node*>* nodes = frame->nodes; // Dump any waiting nodes that are holding on to tensors. for (const Node* node : *nodes) { int node_id = node->id(); PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id; if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { DumpPendingNodeState(node_id, iteration->input_tensors, false); } } // Then the active nodes. for (const Node* node : *nodes) { int node_id = node->id(); PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id; if (iteration->node_state(pending_id) == PendingCounts::STARTED) { DumpActiveNodeState(node_id, iteration->input_tensors); } } // Show all input tensors in use. 
int total_input_tensors = frame->total_input_tensors; size_t total_bytes = 0; for (int i = 0; i < total_input_tensors; ++i) { const Entry& input = iteration->input_tensors[i]; const Tensor* tensor = GetTensorValueForDump(input); if (tensor->IsInitialized()) { LOG(WARNING) << " Input " << i << ": " << strings::StrCat( "Tensor<type: ", DataTypeString(tensor->dtype()), " shape: ", tensor->shape().DebugString(), ", bytes: ", tensor->TotalBytes(), ">"); total_bytes += tensor->TotalBytes(); } } LOG(WARNING) << " Total bytes " << total_bytes; } void ExecutorState::DumpState() { mutex_lock l(mu_); if (!dumped_on_error_) { LOG(WARNING) << "Dumping state"; for (auto& frame : outstanding_frames_) { LOG(WARNING) << frame.first; FrameState* frame_state = frame.second; mutex_lock frame_lock(frame_state->mu); for (IterationState* iteration : frame_state->iterations) { LOG(WARNING) << " Iteration:"; DumpIterationState(frame_state, iteration); } } dumped_on_error_ = true; } } void ExecutorState::Finish() { mu_.lock(); auto status = status_; auto done_cb = std::move(done_cb_); auto runner = std::move(runner_); mu_.unlock(); if (sync_on_finish_ && status.ok()) { // Block until the device has finished all queued operations. // For devices like GPUs that continue to execute Ops after their Compute methods // have completed, this ensures that control is not returned to the user until the step // (and its side-effects) has actually completed. // 阻止直到设备完成所有排队操作。 // 对于类似于 GPU 的设备,在其 Compute 方法完成后,继续执行 Ops, // 这确保了在步骤(及其副作用)实际完成之前控制不会返回给用户。 status = impl_->params_.device->Sync(); } delete this; CHECK(done_cb != nullptr); runner([=]() { done_cb(status); }); } void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child) { // Get the child frame name. 
string enter_name; Status s = GetNodeAttr(node->def(), "frame_name", &enter_name); DCHECK(s.ok()) << s; const string child_name = MakeFrameName(frame, iter, enter_name); { mutex_lock executor_lock(mu_); auto it = outstanding_frames_.find(child_name); if (it != outstanding_frames_.end()) { *child = it->second; return; } } // Need to create a new frame instance. // Note that this new frame instance is created without any locks. if (vlog_) VLOG(2) << "Create frame: " << child_name; int parallel_iters; s = GetNodeAttr(node->def(), "parallel_iterations", ¶llel_iters); DCHECK(s.ok()) << s; FrameState* temp = new FrameState(impl_, parallel_iters); temp->frame_name = child_name; temp->frame_id = Hash64(child_name); temp->parent_frame = frame; temp->parent_iter = iter; temp->InitializeFrameInfo(enter_name); // 'iterations' is a fixed-length circular buffer. temp->iterations.resize(temp->max_parallel_iterations + 1); // Initialize iteration 0. temp->iterations[0] = new IterationState(temp->pending_counts, temp->total_input_tensors); { mutex_lock executor_lock(mu_); auto it = outstanding_frames_.find(child_name); if (it != outstanding_frames_.end()) { *child = it->second; } else { mutex_lock frame_lock(frame->mu); frame->GetIteration(iter)->outstanding_frame_count++; outstanding_frames_[child_name] = temp; *child = temp; temp = nullptr; } } delete temp; // Not used so delete it. } void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { // First, propagate dead_exits (if any) to the parent frame. FrameState* parent_frame = frame->parent_frame; int64 parent_iter = frame->parent_iter; if (parent_frame != nullptr) { mutex_lock paranet_frame_lock(parent_frame->mu); // Propagate all the dead exits to the parent frame. 
for (const Node* node : frame->dead_exits) { auto parent_iter_state = parent_frame->GetIteration(parent_iter); for (const Edge* e : node->out_edges()) { const Node* dst_node = e->dst(); auto dst_pending_id = impl_->gview_.node(dst_node->id())->pending_id; // TODO(yuanbyu): We don't need this if we require the subgraph // given to an executor not to contain a sink node. if (dst_node->IsSink()) continue; bool dst_dead = true; bool dst_ready = false; // We know this is a dead input to dst. if (IsMerge(dst_node)) { if (e->IsControlEdge()) { parent_iter_state->decrement_pending(dst_pending_id, 2); int count = parent_iter_state->pending(dst_pending_id); int dead_cnt = parent_iter_state->dead_count(dst_pending_id); dst_dead = (dead_cnt == dst_node->num_inputs()); dst_ready = (count == 0) || ((count == 1) && dst_dead); } else { parent_iter_state->increment_dead_count(dst_pending_id); const int dead_cnt = parent_iter_state->dead_count(dst_pending_id); dst_dead = (dead_cnt == dst_node->num_inputs()); dst_ready = (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead; } } else { parent_iter_state->increment_dead_count(dst_pending_id); dst_ready = (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0); } if (dst_ready) { if (IsControlTrigger(dst_node)) dst_dead = false; ready->push_back(TaggedNode(dst_node, parent_frame, parent_iter, dst_dead)); parent_iter_state->outstanding_ops++; } } } } // Delete the frame. 
const string& frame_name = frame->frame_name; if (vlog_) VLOG(2) << "Delete frame " << frame_name; { mutex_lock executor_lock(mu_); outstanding_frames_.erase(frame_name); } delete frame; } void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready) { bool is_frame_done = false; { mutex_lock frame_lock(frame->mu); frame->GetIteration(iter)->outstanding_frame_count--; is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready); } if (is_frame_done) { FrameState* parent_frame = frame->parent_frame; int64 parent_iter = frame->parent_iter; DeleteFrame(frame, ready); if (parent_frame != nullptr) { // The completion of frame may cause completions in its parent frame. // So clean things up recursively. CleanupFramesIterations(parent_frame, parent_iter, ready); } } } void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, EntryVector* outputs, TaggedNodeSeq* ready) { const GraphView& gview = executor->gview_; IterationState* iter_state = GetIteration(iter); const int num_output_edges = item->num_output_edges; const EdgeInfo* edges = item->output_edge_list(); Entry* input_tensors = iter_state->input_tensors; for (int out_index = 0; out_index < num_output_edges; out_index++) { const EdgeInfo& e = edges[out_index]; const int dst_id = e.dst_id; const NodeItem* dst_item = gview.node(dst_id); const PendingCounts::Handle dst_pending_id = dst_item->pending_id; const int src_slot = e.output_slot; // TODO(yuanbyu): We don't need this if we require the subgraph // given to an executor not to contain a sink node. if (dst_item->is_sink) continue; bool dst_dead = false; bool dst_ready = false; // True iff this input for dst is needed. We only set this input for dst if this flag // is true. This is needed to make the thread safety analysis happy. 
const bool is_control_edge = (src_slot == Graph::kControlSlot); bool dst_need_input = !is_control_edge; if (dst_item->is_merge) { // A merge node is ready if all control inputs have arrived and either // a) a live data input becomes available or b) all data inputs are // dead. For Merge, pending's LSB is set iff a live data input has arrived. if (is_control_edge) { iter_state->decrement_pending(dst_pending_id, 2); int count = iter_state->pending(dst_pending_id); int dead_cnt = iter_state->dead_count(dst_pending_id); dst_dead = (dead_cnt == dst_item->num_inputs); dst_ready = (count == 0) || ((count == 1) && dst_dead); } else { if ((*outputs)[src_slot].has_value) { // This is a live data input. int count = iter_state->pending(dst_pending_id); iter_state->mark_live(dst_pending_id); // Only the first live edge sets the input and (potentially) triggers execution. // The low bit of count is set if and only if no live input has been used yet // (mark_live clears it). The node should be started if and only if this is // the first live input and there are no pending control edges, i.e. count == 1. dst_ready = (count == 1); dst_need_input = ((count & 0x1) == 1); } else { // This is a dead data input. Note that dst_node is dead if node is a dead enter. // We need this to handle properly a while loop on the untaken branch of a conditional. // TODO(yuanbyu): This is a bit hacky, but a good solution for now. 
iter_state->increment_dead_count(dst_pending_id); const int dead_cnt = iter_state->dead_count(dst_pending_id); dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter; dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead; dst_need_input = false; } } } else { bool increment_dead = (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value)); int pending, dead; iter_state->adjust_for_activation(dst_pending_id, increment_dead, &pending, &dead); dst_dead = (dead > 0); dst_ready = (pending == 0); } if (dst_need_input) { const int dst_slot = e.input_slot; const int dst_loc = dst_item->input_start + dst_slot; if (e.is_last) { input_tensors[dst_loc] = std::move((*outputs)[src_slot]); } else { input_tensors[dst_loc] = (*outputs)[src_slot]; } } // Add dst to the ready queue if it's ready if (dst_ready) { if (dst_item->is_control_trigger) dst_dead = false; ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead)); iter_state->outstanding_ops++; } } } void ExecutorState::FrameState::ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) { // Propagate the deferred NextIteration nodes to the new iteration. for (auto& node_entry : next_iter_roots) { const Node* node = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = !entry.has_value; const NodeItem* item = gview->node(node->id()); EntryVector outputs{entry}; ActivateNodes(item, is_dead, iter, &outputs, ready); } next_iter_roots.clear(); } void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) { // Propagate loop invariants to the new iteration. 
for (auto& node_entry : inv_values) { const Node* node = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = !entry.has_value; const NodeItem* item = gview->node(node->id()); EntryVector outputs{entry}; ActivateNodes(item, is_dead, iter, &outputs, ready); } } void ExecutorState::FrameState::AddLoopInv(const NodeItem* item, const Entry& entry, TaggedNodeSeq* ready) { // Store this value. inv_values.push_back({item->node, entry}); // Make this value available to all iterations. bool is_dead = !entry.has_value; for (int i = 0; i <= iteration_count; ++i) { EntryVector outputs{entry}; ActivateNodes(item, is_dead, i, &outputs, ready); } } bool ExecutorState::FrameState::IsIterationDone(int64 iter) { IterationState* iter_state = GetIteration(iter); if (iter_state->outstanding_ops == 0 && iter_state->outstanding_frame_count == 0) { if (iter == 0) { // The enclosing frame has no pending input. return num_pending_inputs == 0; } else { // The preceding iteration is deleted (and therefore done). return (GetIteration(iter - 1) == nullptr); } } return false; } void ExecutorState::FrameState::IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready) { iteration_count++; int64 next_iter = iteration_count; // Initialize the next iteration. IterationState* iter_state = new IterationState(pending_counts, total_input_tensors); SetIteration(next_iter, iter_state); num_outstanding_iterations++; dead_exits.clear(); // Activate the successors of the deferred roots in the new iteration. ActivateNexts(gview, next_iter, ready); // Activate the loop invariants in the new iteration. ActivateLoopInvs(gview, next_iter, ready); } bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview, int64 iter, TaggedNodeSeq* ready) { int64 curr_iter = iter; while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { // Delete the iteration curr_iter. 
delete GetIteration(curr_iter); SetIteration(curr_iter, nullptr); --num_outstanding_iterations; ++curr_iter; // When one iteration is completed, we check for deferred iteration, // and start it if there is one. if (!next_iter_roots.empty()) { IncrementIteration(gview, ready); } } return IsFrameDone(); } // directsession.cc // item.executor->RunAsync(args, barrier->Get()); // RunAsync 开始异步运行 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { (new ExecutorState(args, this))->RunAsync(done); } } // end namespace Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph, Executor** executor) { ExecutorImpl* impl = new ExecutorImpl(params, graph); Status s = impl->Initialize(); if (s.ok()) { *executor = impl; } else { delete impl; } return s; } Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, const NodeDef& ndef, int graph_def_version, OpKernel** kernel) { auto device_type = DeviceType(device->attributes().device_type()); auto allocator = device->GetAllocator(AllocatorAttributes()); return CreateOpKernel(device_type, device, allocator, flib, ndef, graph_def_version, kernel); } void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } } // end namespace tensorflow

// 转载请注明原文地址 (when reposting, please credit the original URL): https://www.6miu.com/read-22007.html

// 最新回复(0) (latest replies: 0)