use core::mem;
use core::ptr::copy_nonoverlapping;

use rawpointer::PointerExt;

use crate::kernel::Element;
use crate::kernel::ConstNum;
#[cfg(feature = "std")]
macro_rules! fmuladd {
    // fused multiply-add: $dst = $a * $b + $dst with a single rounding step
    (fma_yes, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst = $a.mul_add($b, $dst);
        }
    };
    // unfused: separate multiply and add (two rounding steps)
    (fma_no, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}

// Without std, float mul_add is unavailable; always use the unfused form.
#[cfg(not(feature = "std"))]
macro_rules! fmuladd {
    ($any:tt, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}
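
// A minimal sketch (not part of the kernel) of what the two arms above expand
// to; the test name and values are illustrative only. With the `std` feature,
// `fma_yes` maps to `f32::mul_add` (one rounding step) while `fma_no` and the
// no-std fallback use a separate multiply and add.
#[cfg(all(test, feature = "std"))]
#[test]
fn fmuladd_arms_demo() {
    let (a, b) = (2.0_f32, 3.0_f32);
    let mut fused = 1.0_f32;
    let mut plain = 1.0_f32;
    fmuladd!(fma_yes, fused, a, b); // expands to: fused = a.mul_add(b, fused)
    fmuladd!(fma_no, plain, a, b);  // expands to: plain += a * b
    // exact for these inputs; in general the two can differ in the last bit
    assert_eq!(fused, 7.0);
    assert_eq!(plain, 7.0);
}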

// Fallback complex GEMM microkernel. `$elem_ty` is the complex element type,
// laid out as [$real_ty; 2] (real part, then imaginary part); the packed A/B
// panels store each MR- resp. NR-wide slice split into all real parts followed
// by all imaginary parts (see pack_complex below). Relies on loop_m!/loop_n!,
// unroll_by!, at() and mul() being in scope at the point of invocation.
macro_rules! kernel_fallback_impl_complex {
    ([$($attr:meta)*] [$fma_opt:tt] $name:ident, $elem_ty:ty, $real_ty:ty, $mr:expr, $nr:expr, $unroll:tt) => {
        $(#[$attr])*
        unsafe fn $name(k: usize, alpha: $elem_ty, a: *const $elem_ty, b: *const $elem_ty,
                        beta: $elem_ty, c: *mut $elem_ty, rsc: isize, csc: isize)
        {
            const MR: usize = $mr;
            const NR: usize = $nr;

            debug_assert_eq!(beta, <$elem_ty>::zero(),
                             "Beta must be 0; nonzero beta is handled by the masked kernel");

            // split storage for one k-step: pp/qq are the real/imaginary parts
            // of the A column, rr/ss the real/imaginary parts of the B row
            let mut pp = [<$real_ty>::zero(); MR];
            let mut qq = [<$real_ty>::zero(); MR];
            let mut rr = [<$real_ty>::zero(); NR];
            let mut ss = [<$real_ty>::zero(); NR];

            // MR x NR accumulator for A * B (alpha is applied at write-back)
            let mut ab: [[$elem_ty; NR]; MR] = [[<$elem_ty>::zero(); NR]; MR];
            let mut areal = a as *const $real_ty;
            let mut breal = b as *const $real_ty;

            unroll_by!($unroll => k, {
                // packed layout per k-step: MR (resp. NR) real parts, then
                // MR (resp. NR) imaginary parts
                let aimag = areal.add(MR);
                let bimag = breal.add(NR);

                loop_m!(i, {
                    pp[i] = at(areal, i);
                    qq[i] = at(aimag, i);
                });
                loop_n!(j, {
                    rr[j] = at(breal, j);
                    ss[j] = at(bimag, j);
                });
                loop_m!(i, {
                    loop_n!(j, {
                        // complex multiply-accumulate:
                        // (p + qi)(r + si) = (pr - qs) + (ps + qr)i
                        fmuladd!($fma_opt, ab[i][j][0], pp[i], rr[j]);
                        fmuladd!($fma_opt, ab[i][j][1], pp[i], ss[j]);
                        fmuladd!($fma_opt, ab[i][j][0], -qq[i], ss[j]);
                        fmuladd!($fma_opt, ab[i][j][1], qq[i], rr[j]);
                    })
                });

                // advance past this k-step's real and imaginary halves
                areal = aimag.add(MR);
                breal = bimag.add(NR);
            });

            macro_rules! c {
                ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
            }

            // write back C = alpha * (A * B); beta is asserted to be zero above
            loop_n!(j, loop_m!(i, *c![i, j] = mul(alpha, ab[i][j])));
        }
    };
}
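
// A small self-contained check (illustrative, independent of the macro above)
// of the complex multiply-accumulate identity the four fmuladd calls encode:
// (p + qi)(r + si) = (pr - qs) + (ps + qr)i, accumulated into ab[i][j].
#[cfg(test)]
#[test]
fn complex_muladd_identity_demo() {
    let (p, q) = (1.5_f64, -2.0_f64); // one element of A: p + q*i
    let (r, s) = (0.5_f64, 3.0_f64);  // one element of B: r + s*i
    let mut ab = [0.0_f64; 2];        // accumulator, [real, imag]
    // the same four updates the kernel performs per (i, j) pair
    ab[0] += p * r;
    ab[1] += p * s;
    ab[0] += -q * s;
    ab[1] += q * r;
    assert_eq!(ab, [p * r - q * s, p * s + q * r]);
}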

// Defines pack_mr/pack_nr for a complex kernel; both forward to pack_complex,
// parameterized by the kernel's MR resp. NR width.
macro_rules! pack_methods {
    () => {
        #[inline]
        unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                          a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::MRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }

        #[inline]
        unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                          a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::NRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }
    }
}
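
// Usage sketch (assumed shape, not verbatim from a kernel file): a complex
// kernel implementation would pull both packing methods in with
//
//     impl GemmKernel for KernelFallback {
//         // ... associated types/constants and the kernel fn ...
//         pack_methods!{}
//     }
//
// so that A panels (MR wide) and B panels (NR wide) share pack_complex below.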


/// Pack a complex panel into split ("deinterleaved") format: for each of the
/// kc columns of a row block, mr real parts followed by mr imaginary parts.
/// `rsa`/`csa` are the strides of `a` counted in whole complex elements.
pub(crate) unsafe fn pack_complex<MR, T, TReal>(kc: usize, mc: usize, pack: &mut [T],
                                                a: *const T, rsa: isize, csa: isize)
    where MR: ConstNum,
          T: Element,
          TReal: Element,
{
    // view the complex data as reals; the imaginary part of each element
    // directly follows its real part
    let pack = pack.as_mut_ptr() as *mut TReal;
    let areal = a as *const TReal;
    let aimag = areal.add(1);

    // T must consist of exactly two TReal values
    assert_eq!(mem::size_of::<T>(), 2 * mem::size_of::<TReal>());

    let mr = MR::VALUE;
    let mut p = 0; // write index into pack, counted in TReal units

    // pack all full mr-row blocks
    for ir in 0..mc/mr {
        let row_offset = ir * mr;
        for j in 0..kc {
            // mr real parts of column j ...
            for i in 0..mr {
                let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
                                 .stride_offset(2 * csa, j);
                copy_nonoverlapping(a_elt, pack.add(p), 1);
                p += 1;
            }
            // ... then the mr matching imaginary parts
            for i in 0..mr {
                let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
                                 .stride_offset(2 * csa, j);
                copy_nonoverlapping(a_elt, pack.add(p), 1);
                p += 1;
            }
        }
    }

    let zero = TReal::zero();

    // if mc is not a multiple of mr, pack the final partial row block,
    // zero-padding rows rest..mr
    let rest = mc % mr;
    if rest > 0 {
        let row_offset = (mc/mr) * mr;
        for j in 0..kc {
            for i in 0..mr {
                if i < rest {
                    let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
                                     .stride_offset(2 * csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    *pack.add(p) = zero;
                }
                p += 1;
            }
            for i in 0..mr {
                if i < rest {
                    let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
                                     .stride_offset(2 * csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    *pack.add(p) = zero;
                }
                p += 1;
            }
        }
    }
}
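
// Layout sketch (self-contained; re-derives the packed order by hand rather
// than calling pack_complex): with mr = 2 and kc = 1, a column of two complex
// numbers (1 + 2i), (3 + 4i) stored interleaved packs as mr real parts
// followed by mr imaginary parts.
#[cfg(test)]
#[test]
fn pack_complex_layout_demo() {
    let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0]; // interleaved [re, im, re, im]
    let mut pack = [0.0_f32; 4];
    let mr = 2;
    for i in 0..mr {
        pack[i] = a[2 * i];          // real parts first
        pack[mr + i] = a[2 * i + 1]; // then imaginary parts
    }
    assert_eq!(pack, [1.0, 3.0, 2.0, 4.0]);
}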