// matrixmultiply/cgemm_common.rs

1// Copyright 2021-2023 Ulrik Sverdrup "bluss"
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use core::mem;
10use core::ptr::copy_nonoverlapping;
11
12use rawpointer::PointerExt;
13
14use crate::kernel::Element;
15use crate::kernel::ConstNum;
16
#[cfg(feature = "std")]
macro_rules! fmuladd {
    // conceptually $dst += $a * $b, optionally use fused multiply-add
    //
    // fma_yes uses a single fused multiply-add (one rounding step);
    // `mul_add` is only available with std, hence the cfg gate above.
    (fma_yes, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst = $a.mul_add($b, $dst);
        }
    };
    // fma_no: plain multiply then add (two rounding steps).
    (fma_no, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}
31
// no_std fallback: `mul_add` is unavailable without std, so both the
// fma_yes and fma_no selectors degrade to the separate multiply-and-add.
#[cfg(not(feature = "std"))]
macro_rules! fmuladd {
    ($any:tt, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}
40
41
// kernel fallback impl macro
// Depends on a couple of macro and function definitions to be in scope - loop_m/_n, at, etc.
// $fma_opt: fma_yes or fma_no to use f32::mul_add etc or not
macro_rules! kernel_fallback_impl_complex {
    // $name: name of the generated kernel function
    // $elem_ty: complex element type (a pair of reals, indexable as [re, im])
    // $real_ty: underlying real scalar type
    // $mr, $nr: register block dimensions (MR x NR)
    // $unroll: unroll factor passed to unroll_by! for the k-loop
    ([$($attr:meta)*] [$fma_opt:tt] $name:ident, $elem_ty:ty, $real_ty:ty, $mr:expr, $nr:expr, $unroll:tt) => {
    $(#[$attr])*
    // Computes one MR x NR block of C = alpha * A * B, where a/b point at
    // panels laid out by pack_complex (alternating rows of real and
    // imaginary parts). beta is required to be zero here; general beta
    // handling is masked out by the caller.
    unsafe fn $name(k: usize, alpha: $elem_ty, a: *const $elem_ty, b: *const $elem_ty,
                    beta: $elem_ty, c: *mut $elem_ty, rsc: isize, csc: isize)
    {
        const MR: usize = $mr;
        const NR: usize = $nr;

        debug_assert_eq!(beta, <$elem_ty>::zero(), "Beta must be 0 or is not masked");

        // Scratch rows for the current k-step: real (pp) and imaginary (qq)
        // parts of A's column, real (rr) and imaginary (ss) parts of B's row.
        let mut pp  = [<$real_ty>::zero(); MR];
        let mut qq  = [<$real_ty>::zero(); MR];
        let mut rr  = [<$real_ty>::zero(); NR];
        let mut ss  = [<$real_ty>::zero(); NR];

        // Accumulator for the MR x NR product block.
        let mut ab: [[$elem_ty; NR]; MR] = [[<$elem_ty>::zero(); NR]; MR];
        // Walk the packed panels in units of the real scalar type.
        let mut areal = a as *const $real_ty;
        let mut breal = b as *const $real_ty;

        unroll_by!($unroll => k, {
            // We set:
            // P + Q i = A
            // R + S i = B
            //
            // see pack_complex for how data is packed
            let aimag = areal.add(MR);
            let bimag = breal.add(NR);

            // AB = PR - QS + i (QR + PS)
            loop_m!(i, {
                pp[i] = at(areal, i);
                qq[i] = at(aimag, i);
            });
            loop_n!(j, {
                rr[j] = at(breal, j);
                ss[j] = at(bimag, j);
            });
            loop_m!(i, {
                loop_n!(j, {
                    // optionally use fma
                    // ab[i][j] is [re, im]; accumulate the four partial
                    // products of the complex multiply given above.
                    fmuladd!($fma_opt, ab[i][j][0], pp[i], rr[j]);
                    fmuladd!($fma_opt, ab[i][j][1], pp[i], ss[j]);
                    fmuladd!($fma_opt, ab[i][j][0], -qq[i], ss[j]);
                    fmuladd!($fma_opt, ab[i][j][1], qq[i], rr[j]);
                })
            });

            // Advance past this k-step's real and imaginary rows.
            areal = aimag.add(MR);
            breal = bimag.add(NR);
        });

        // c![i, j] addresses C(i, j) through the row/column strides.
        macro_rules! c {
            ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
        }

        // set C = α A B
        loop_n!(j, loop_m!(i, *c![i, j] = mul(alpha, ab[i][j])));
    }
    };
}
106
/// GemmKernel packing trait methods
///
/// Expands to `pack_mr`/`pack_nr` implementations that forward to
/// `pack_complex`. Assumes `T`, `TReal` and `pack_complex` are in scope at
/// the expansion site, and that `Self::MRTy`/`Self::NRTy` name the kernel's
/// register-block sizes.
macro_rules! pack_methods {
    () => {
        // Pack an mc x kc panel of A, MR rows at a time, splitting each
        // complex row into a real row followed by an imaginary row.
        #[inline]
        unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                          a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::MRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }

        // Same packing, but with the NR block size (used for B panels).
        #[inline]
        unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                        a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::NRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }
    }
}
125
126
127/// Pack complex: similar to general packing but separate rows for real and imag parts.
128///
129/// Source matrix contains [p0 + q0i, p1 + q1i, p2 + q2i, ..] and it's packed into
130/// alternate rows of real and imaginary parts.
131///
132/// [ p0 p1 p2 p3 .. (MR repeats)
133///   q0 q1 q2 q3 .. (MR repeats)
134///   px p_ p_ p_ .. (x = MR)
135///   qx q_ q_ q_ .. (x = MR)
136///   py p_ p_ p_ .. (y = 2 * MR)
137///   qy q_ q_ q_ .. (y = 2 * MR)
138///   ...
139/// ]
140pub(crate) unsafe fn pack_complex<MR, T, TReal>(kc: usize, mc: usize, pack: &mut [T],
141                                                a: *const T, rsa: isize, csa: isize)
142    where MR: ConstNum,
143          T: Element,
144          TReal: Element,
145{
146    // use pointers as pointer to TReal
147    let pack = pack.as_mut_ptr() as *mut TReal;
148    let areal = a as *const TReal;
149    let aimag = areal.add(1);
150
151    assert_eq!(mem::size_of::<T>(), 2 * mem::size_of::<TReal>());
152
153    let mr = MR::VALUE;
154    let mut p = 0; // offset into pack
155
156    // general layout case (no contig case when stride != 1)
157    for ir in 0..mc/mr {
158        let row_offset = ir * mr;
159        for j in 0..kc {
160            // real row
161            for i in 0..mr {
162                let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
163                                 .stride_offset(2 * csa, j);
164                copy_nonoverlapping(a_elt, pack.add(p), 1);
165                p += 1;
166            }
167            // imag row
168            for i in 0..mr {
169                let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
170                                 .stride_offset(2 * csa, j);
171                copy_nonoverlapping(a_elt, pack.add(p), 1);
172                p += 1;
173            }
174        }
175    }
176
177    let zero = TReal::zero();
178
179    // Pad with zeros to multiple of kernel size (uneven mc)
180    let rest = mc % mr;
181    if rest > 0 {
182        let row_offset = (mc/mr) * mr;
183        for j in 0..kc {
184            // real row
185            for i in 0..mr {
186                if i < rest {
187                    let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
188                                     .stride_offset(2 * csa, j);
189                    copy_nonoverlapping(a_elt, pack.add(p), 1);
190                } else {
191                    *pack.add(p) = zero;
192                }
193                p += 1;
194            }
195            // imag row
196            for i in 0..mr {
197                if i < rest {
198                    let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
199                                     .stride_offset(2 * csa, j);
200                    copy_nonoverlapping(a_elt, pack.add(p), 1);
201                } else {
202                    *pack.add(p) = zero;
203                }
204                p += 1;
205            }
206        }
207    }
208}