|
1 | 1 | use itertools::Itertools; |
2 | | -use ndarray::{prelude::*, IntoDimension, Slice}; |
| 2 | +use ndarray::{prelude::*, IntoDimension, Order, Slice}; |
3 | 3 |
|
4 | 4 | use crate::{ |
5 | 5 | util::dim_from_vec, |
6 | 6 | CsapsError::{ReshapeFrom2d, ReshapeTo2d}, |
7 | 7 | Real, Result, |
8 | 8 | }; |
9 | 9 |
|
| 10 | +pub(crate) fn reshape_order<T, D>(data: &ArrayView<'_, T, D>) -> Order |
| 11 | +where |
| 12 | + D: Dimension, |
| 13 | +{ |
| 14 | + if data.is_standard_layout() { |
| 15 | + Order::RowMajor |
| 16 | + } else if data.ndim() > 1 && data.raw_view().reversed_axes().is_standard_layout() { |
| 17 | + Order::ColumnMajor |
| 18 | + } else { |
| 19 | + Order::RowMajor |
| 20 | + } |
| 21 | +} |
| 22 | + |
10 | 23 | pub fn diff<'a, T: 'a, D, V>(data: V, axis: Option<Axis>) -> Array<T, D> |
11 | 24 | where |
12 | 25 | T: Real<T>, |
|
43 | 56 | let axis_size = shape[axis.0]; |
44 | 57 | let new_shape = [numel / axis_size, axis_size]; |
45 | 58 |
|
46 | | - match data_view.permuted_axes(axes).into_shape(new_shape) { |
| 59 | + let data_view = data_view.permuted_axes(axes); |
| 60 | + let order = reshape_order(&data_view); |
| 61 | + |
| 62 | + match data_view.into_shape_with_order((new_shape, order)) { |
47 | 63 | Ok(view_2d) => Ok(view_2d), |
48 | 64 | Err(error) => Err(ReshapeTo2d { |
49 | 65 | input_shape: shape, |
|
62 | 78 | let shape = data.shape().to_vec(); |
63 | 79 | let new_shape = [shape[0..(ndim - 1)].iter().product(), shape[ndim - 1]]; |
64 | 80 |
|
65 | | - match data.into_shape(new_shape) { |
| 81 | + let order = reshape_order(&data); |
| 82 | + |
| 83 | + match data.into_shape_with_order((new_shape, order)) { |
66 | 84 | Ok(data_2d) => Ok(data_2d), |
67 | 85 | Err(error) => Err(ReshapeTo2d { |
68 | 86 | input_shape: shape, |
|
93 | 111 | let new_shape: D = dim_from_vec(ndim, new_shape_vec.clone()); |
94 | 112 | let data_view = data.into(); |
95 | 113 |
|
96 | | - match data_view.into_shape(new_shape) { |
| 114 | + let order = reshape_order(&data_view); |
| 115 | + |
| 116 | + match data_view.into_shape_with_order((new_shape, order)) { |
97 | 117 | Ok(view_nd) => { |
98 | 118 | let mut axes_tmp: Vec<usize> = (0..ndim).collect(); |
99 | 119 | let end_axis = axes_tmp.pop().unwrap(); |
|
0 commit comments