Skip to content

Commit 584f581

Browse files
authored
ml-kem, module-lattice: avoid UDIV in compiled output (#289)
When compiling ML-KEM and checking the resulting binary for side-channel leakage, several false positive UDIVs appear on ARM assembly. This is a mild annoyance, but not security-relevant. This PR updates module-lattice and ml-kem to avoid the division operators entirely. As an added bonus, `cargo bench` reports a performance win: | Operation | master | this branch | Δ | p-value | |-------------|----------|----------|---------|---------| | keygen | 31.00 µs | 26.57 µs | −14.08% | < 0.05 | | encapsulate | 27.80 µs | 22.78 µs | −18.80% | < 0.05 | | decapsulate | 34.48 µs | 26.18 µs | −23.30% | < 0.05 | | round_trip | 99.46 µs | 82.35 µs | −17.27% | < 0.05 | ## Raw criterion output ### master (baseline) ``` keygen time: [30.917 µs 31.003 µs 31.115 µs] encapsulate time: [27.637 µs 27.802 µs 28.046 µs] decapsulate time: [34.279 µs 34.479 µs 34.778 µs] round_trip time: [99.161 µs 99.463 µs 99.854 µs] ``` ### ml-kem-undivided (compared against master) ``` keygen time: [26.493 µs 26.574 µs 26.691 µs] change: [−14.429% −14.079% −13.765%] (p = 0.00 < 0.05) Performance has improved. encapsulate time: [22.478 µs 22.781 µs 23.228 µs] change: [−19.472% −18.797% −17.936%] (p = 0.00 < 0.05) Performance has improved. decapsulate time: [26.089 µs 26.185 µs 26.304 µs] change: [−23.952% −23.304% −22.512%] (p = 0.00 < 0.05) Performance has improved. round_trip time: [81.947 µs 82.345 µs 83.057 µs] change: [−17.604% −17.269% −16.761%] (p = 0.00 < 0.05) Performance has improved. ``` ## Claude's Interpretation > [!NOTE] > Take this with a grain of salt, but it does sound plausible. - **NTT const-generic layers** (`ntt_layer<LEN, ITERATIONS>` / `ntt_inverse_layer<LEN, ITERATIONS>`) are the dominant win. With `LEN` and `ITERATIONS` compile-time constants, the inner loops unroll completely and LLVM auto-vectorizes the butterfly into NEON (`add.8h`, `sub.8h`, `cmhs.8h`, `bic.8h`). In the original form, `(0..256).step_by(2 * len)` carried a runtime `UDIV` and blocked unrolling through the outer `for len in [...]`. - **Decapsulate benefits the most (−23%)** because it runs both `ntt` and `ntt_inverse` on the length-`K` vector and also hits the `D = 12` `byte_decode` path. - **Keygen (−14%)** mainly benefits from the forward NTT on the secret and error vectors. - **Encapsulate (−19%)** benefits from the forward NTT on the randomness vector and matrix-vector product in the NTT domain.
1 parent 5b84cfb commit 584f581

3 files changed

Lines changed: 69 additions & 29 deletions

File tree

ml-kem/src/algebra.rs

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -134,26 +134,42 @@ pub(crate) trait Ntt {
134134
fn ntt(&self) -> Self::Output;
135135
}
136136

137+
/// One layer of the forward NTT butterfly.
138+
///
139+
/// `LEN` is the butterfly half-length and `ITERATIONS = 128 / LEN` is the number of
140+
/// butterfly groups in the layer. Making both compile-time constants lets the compiler
141+
/// eliminate the iterator length calculation (`256 / (2 * LEN)`) that `step_by` would
142+
/// otherwise compute with a `UDIV` instruction.
143+
#[inline(always)]
144+
fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(f: &mut Array<Elem, U256>, k: &mut usize) {
145+
for i in 0..ITERATIONS {
146+
let start = i * 2 * LEN;
147+
let zeta = ZETA_POW_BITREV[*k];
148+
*k += 1;
149+
150+
for j in start..(start + LEN) {
151+
let t = zeta * f[j + LEN];
152+
f[j + LEN] = f[j] - t;
153+
f[j] = f[j] + t;
154+
}
155+
}
156+
}
157+
137158
/// Algorithm 9: `NTT`
138159
impl Ntt for Polynomial {
139160
type Output = NttPolynomial;
140161

141162
fn ntt(&self) -> NttPolynomial {
142163
let mut k = 1;
143-
144164
let mut f = self.0;
145-
for len in [128, 64, 32, 16, 8, 4, 2] {
146-
for start in (0..256).step_by(2 * len) {
147-
let zeta = ZETA_POW_BITREV[k];
148-
k += 1;
149-
150-
for j in start..(start + len) {
151-
let t = zeta * f[j + len];
152-
f[j + len] = f[j] - t;
153-
f[j] = f[j] + t;
154-
}
155-
}
156-
}
165+
166+
ntt_layer::<128, 1>(&mut f, &mut k);
167+
ntt_layer::<64, 2>(&mut f, &mut k);
168+
ntt_layer::<32, 4>(&mut f, &mut k);
169+
ntt_layer::<16, 8>(&mut f, &mut k);
170+
ntt_layer::<8, 16>(&mut f, &mut k);
171+
ntt_layer::<4, 32>(&mut f, &mut k);
172+
ntt_layer::<2, 64>(&mut f, &mut k);
157173

158174
f.into()
159175
}
@@ -175,26 +191,42 @@ pub(crate) trait NttInverse {
175191
fn ntt_inverse(&self) -> Self::Output;
176192
}
177193

194+
/// One layer of the inverse NTT butterfly.
195+
///
196+
/// See [`ntt_layer`] for the rationale behind the const generics.
197+
#[inline(always)]
198+
fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
199+
f: &mut Array<Elem, U256>,
200+
k: &mut usize,
201+
) {
202+
for i in 0..ITERATIONS {
203+
let start = i * 2 * LEN;
204+
let zeta = ZETA_POW_BITREV[*k];
205+
*k -= 1;
206+
207+
for j in start..(start + LEN) {
208+
let t = f[j];
209+
f[j] = t + f[j + LEN];
210+
f[j + LEN] = zeta * (f[j + LEN] - t);
211+
}
212+
}
213+
}
214+
178215
/// Algorithm 10: `NTT^{-1}`
179216
impl NttInverse for NttPolynomial {
180217
type Output = Polynomial;
181218

182219
fn ntt_inverse(&self) -> Polynomial {
183220
let mut f: Array<Elem, U256> = self.0.clone();
184-
185221
let mut k = 127;
186-
for len in [2, 4, 8, 16, 32, 64, 128] {
187-
for start in (0..256).step_by(2 * len) {
188-
let zeta = ZETA_POW_BITREV[k];
189-
k -= 1;
190-
191-
for j in start..(start + len) {
192-
let t = f[j];
193-
f[j] = t + f[j + len];
194-
f[j + len] = zeta * (f[j + len] - t);
195-
}
196-
}
197-
}
222+
223+
ntt_inverse_layer::<2, 64>(&mut f, &mut k);
224+
ntt_inverse_layer::<4, 32>(&mut f, &mut k);
225+
ntt_inverse_layer::<8, 16>(&mut f, &mut k);
226+
ntt_inverse_layer::<16, 8>(&mut f, &mut k);
227+
ntt_inverse_layer::<32, 4>(&mut f, &mut k);
228+
ntt_inverse_layer::<64, 2>(&mut f, &mut k);
229+
ntt_inverse_layer::<128, 1>(&mut f, &mut k);
198230

199231
Elem::new(3303) * &Polynomial::new(f)
200232
}

module-lattice/src/algebra.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ macro_rules! define_field {
7272
const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;
7373

7474
fn small_reduce(x: Self::Int) -> Self::Int {
75-
if x < Self::Q { x } else { x - Self::Q }
75+
// Branchless conditional subtraction: if x >= Q, subtract Q; else
76+
// leave x alone. Compilers already emit `csel` here at O2, but the
77+
// explicit mask form removes the dependency on optimizer choices
78+
// and keeps the generated assembly free of secret-dependent control
79+
// flow at every optimization level.
80+
let mask = ((x >= Self::Q) as $int).wrapping_neg();
81+
x - (Self::Q & mask)
7682
}
7783

7884
fn barrett_reduce(x: Self::Long) -> Self::Int {

module-lattice/src/encoding.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ pub fn byte_decode<F: Field, D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> D
130130
let val = F::Int::truncate(x >> (D::USIZE * j));
131131
vj.0 = val & mask;
132132

133-
// Special case for FIPS 203
133+
// Special case for FIPS 203. For 12-bit values (max 4095) with Q = 3329,
134+
// the masked value is always in [0, 2Q), so `small_reduce` is exact and
135+
// avoids the hardware UDIV that `% F::Q` would emit.
134136
if D::USIZE == 12 {
135-
vj.0 = vj.0 % F::Q;
137+
vj.0 = F::small_reduce(vj.0);
136138
}
137139
}
138140
}

0 commit comments

Comments
 (0)