Browse Source

Fix bug where span was not correctly considered

Christoph Stelz 4 months ago
parent
commit
6bb7efdd19
1 changed files with 22 additions and 13 deletions
  1. 22 13
      src/ipc_debug.cpp

+ 22 - 13
src/ipc_debug.cpp

@@ -271,27 +271,36 @@ void debug_ipc_mpi_set_data_distribution(int start_index, uint64_t local_length)
 
 
 void debug_ipc_assert_equal_mpi_double_array(double *array, size_t array_length, int span) {
+
     if (rank == 0) {
-        size_t total_length = start_indices_list[cluster_size];
+        std::vector<int> actual_recv_counts(recv_counts);
+        std::vector<int> displacements(recv_counts);
+
+        size_t total_length = span * start_indices_list[cluster_size];
         float_buffer.resize(total_length); // Make sure we have enough space for all values
-        assert(recv_counts.size() == static_cast<unsigned>(cluster_size));
+        assert(actual_recv_counts.size() == static_cast<unsigned>(cluster_size));
         assert(start_indices_map.size() == static_cast<unsigned>(cluster_size + 1));
 
         for (unsigned int i = 0; i < recv_counts.size(); i++) {
-            recv_counts[i] *= span;
+            actual_recv_counts[i] *= span;
         }
-    }
-
-    // Gather all data at the root.
-    MPI_Gatherv(array, array_length, MPI_DOUBLE,
-            float_buffer.data(), recv_counts.data(), start_indices_list.data(),
-            MPI_DOUBLE, 0, MPI_COMM_WORLD);
-
-    if (rank == 0) {
-        debug_ipc_assert_equal_vector(float_buffer);
 
         for (unsigned int i = 0; i < recv_counts.size(); i++) {
-            recv_counts[i] /= span;
+            displacements[i] *= span;
         }
+
+        MPI_Gatherv(array, array_length * span, MPI_DOUBLE,
+                float_buffer.data(), // recvbuf
+                actual_recv_counts.data(),  // recv counts
+                start_indices_list.data(), // displacements
+                MPI_DOUBLE, 0, MPI_COMM_WORLD);
+        debug_ipc_assert_equal_vector(float_buffer);
+    } else {
+        // Gather all data at the root.
+        MPI_Gatherv(array, array_length * span, MPI_DOUBLE,
+                nullptr, // recvbuf
+                nullptr,  // recv counts
+                nullptr, // displacements
+                MPI_DOUBLE, 0, MPI_COMM_WORLD);
     }
 }