Sorting a 3D NumPy array with numpy.argsort gives a puzzling result

Suppose I have the following array:

import numpy as np

x = np.array([[[2, 5],
               [6, 7]],

              [[8, 1],
               [4, 9]]])

I want to sort the rows of each sub-array by their first column to get the following outcome:

out = array([[[2, 5],
              [6, 7]],

             [[4, 9],
              [8, 1]]])

When I run the following code:

x[:,x[:,:,0].argsort()]

the outcome is:

out1 = array([[[[2, 5],
                [6, 7]],

               [[6, 7],
                [2, 5]]],

              [[[8, 1],
                [4, 9]],

               [[4, 9],
                [8, 1]]]])

It turns out my desired outcome is on the diagonal of this 2×2 grid of sub-arrays (out1[0, 0] and out1[1, 1]), so I can still extract it, but I don’t understand what the off-diagonal entries are. They don’t even look sorted in any way.

Where did the off-diagonal arrays come from?

Also, how can I get my desired outcome without having to go through this fairly large and useless array (out1)?

Answer

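You need to use advanced indexing on the first axis as well, instead of a slice. With the slice, every sub-array is reordered by every row of the argsort result, so out1[i, j] is sub-array i reordered with the sort order of sub-array j. The off-diagonal entries are therefore sub-arrays sorted with the wrong order. Pairing each sub-array with its own order gives the result directly: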

x[np.arange(len(x))[:,None], x[...,0].argsort()]

#[[[2 5]
#  [6 7]]

# [[4 9]
#  [8 1]]]
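
To make the origin of the off-diagonal entries concrete, here is a minimal sketch that rebuilds out1 from the question and checks each block against the rule "sub-array i reordered by the sort order of sub-array j". The names idx, rows and out_alt are only illustrative, and the np.take_along_axis variant assumes NumPy 1.15 or later:

```python
import numpy as np

x = np.array([[[2, 5],
               [6, 7]],
              [[8, 1],
               [4, 9]]])

# Row order of each sub-array when sorted by its first column; shape (2, 2).
idx = x[:, :, 0].argsort()

# Slice on axis 0 plus an index array on axis 1: the index array is applied
# to every sub-array, so the result has shape (2, 2, 2, 2).
out1 = x[:, idx]
for i in range(len(x)):
    for j in range(len(idx)):
        # Block (i, j) is sub-array i reordered with the order of sub-array j.
        assert np.array_equal(out1[i, j], x[i][idx[j]])

# Advanced indexing on both axes: rows has shape (2, 1) and broadcasts
# against idx, so sub-array i is paired only with its own order idx[i].
rows = np.arange(len(x))[:, None]
out = x[rows, idx]

# Possible alternative: take_along_axis, with idx given a trailing axis
# so it broadcasts over the last dimension of x.
out_alt = np.take_along_axis(x, idx[..., None], axis=1)

print(out)
```

Both out and out_alt have shape (2, 2, 2), so the larger intermediate array is never built. Which form reads better is a matter of taste; take_along_axis avoids constructing the arange by hand.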