Commit 4e131d27 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TensorFlower Gardener
Browse files

Many algorithms need to enumerate the set of nodes within a graph, while...

Many algorithms need to enumerate the set of nodes within a graph, while excluding the special Sink and Source nodes.  The checks for skipping Source and Sink are duplicated in dozens of loops.

This CL adds a new Graph::op_nodes() method, which returns an enumerable range of all operation nodes, excluding Sink and Source.  This allows many for loops to be simplified.

This simplification is being done mainly for readability / reliability.  There may be a tiny performance difference owing to this change (as well as making the Graph::nodes() and Graph::op_nodes() methods inlineable), but the measured difference is not reliably large enough to be significant.

The changes to graph.h and graph.cc are quite minimal.  I updated all of the uses of Graph::nodes() that I could reliably determine were unaffected by the change.  Most uses immediately checked node->IsOp().  Some compared node->type_string() against literal strings, none of which were "_SINK" or "_SOURCE", and so using op_nodes() was more appropriate than nodes().  In some cases, it was not obvious whether an existing use of Graph::node() wanted to enumerate Sink / Source, so I left those uses unaffected.

PiperOrigin-RevId: 156782112
parent 89e09f63
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -200,7 +200,7 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
  for (const Fetch& fetch : config.fetch()) {
    missing_fetches.insert(TensorIdToString(fetch.id()));
  }
  for (const Node* n : graph->nodes()) {
  for (const Node* n : graph->op_nodes()) {
    if (n->type_string() == kArgOp) {
      string feed_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
+2 −2
Original line number Diff line number Diff line
@@ -125,9 +125,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
  Graph* graph = options.graph->get();

  for (Node* n : graph->nodes()) {
  for (Node* n : graph->op_nodes()) {
    // In all cases, only try to compile computational nodes.
    if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
      continue;
    }

+4 −6
Original line number Diff line number Diff line
@@ -178,8 +178,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
  std::unordered_map<Node*, Node*> node_images;

  // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
  for (Node* node : graph_in_->nodes()) {
    if (node->IsSource() || node->IsSink()) continue;
  for (Node* node : graph_in_->op_nodes()) {
    string func_id = GetFunctionNameAttr(node);
    if (func_id.empty()) continue;

@@ -445,8 +444,7 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
  std::unordered_map<const Node*, Node*> node_images;

  // Copy all unmarked nodes to the output graph.
  for (Node* node : graph_in_->nodes()) {
    if (node->IsSource() || node->IsSink()) continue;
  for (Node* node : graph_in_->op_nodes()) {
    string func_id = GetFunctionNameAttr(node);

    // Don't copy nodes that going to be encapsulated, unless parallel checking
@@ -590,7 +588,7 @@ Status EncapsulateSubgraphsInFunctions(

// Finds the types of the _Arg nodes, indexed by position.
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
  for (Node* n : graph.nodes()) {
  for (Node* n : graph.op_nodes()) {
    if (n->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
@@ -607,7 +605,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
// 'permutation' that maps old indices to new indices.
static Status RenumberArguments(Graph* graph,
                                const std::vector<int>& permutation) {
  for (Node* n : graph->nodes()) {
  for (Node* n : graph->op_nodes()) {
    if (n->type_string() == kArgOp) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+1 −3
Original line number Diff line number Diff line
@@ -120,9 +120,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
  std::unordered_map<string, string> return_values;
  NodeNameMapping node_names;

  for (Node const* node : graph.nodes()) {
    if (!node->IsOp()) continue;

  for (Node const* node : graph.op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      DataType type;
+2 −5
Original line number Diff line number Diff line
@@ -132,8 +132,7 @@ bool IsCompilableCall(const NodeDef& call_def,
    return false;
  }

  for (Node* node : fbody->graph->nodes()) {
    if (node->IsSource() || node->IsSink()) continue;
  for (Node* node : fbody->graph->op_nodes()) {
    if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
      continue;
    if (node->type_string() == "While") {
@@ -176,9 +175,7 @@ Status FindCompilationCandidates(
  std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
      nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));

  for (Node* node : graph.nodes()) {
    if (node->IsSource() || node->IsSink()) continue;

  for (Node* node : graph.op_nodes()) {
    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
Loading