Commit 3e67c72
fix(training): WASM contrastive loss + NAPI optimizer step (#339)
ADR-145: Fix training pipeline issues across WASM and NAPI bindings.

WASM (ruvector-attention-wasm):
- Replace serde_wasm_bindgen deserialization of the negatives param with explicit js_sys::Float32Array conversion. TypedArrays don't deserialize via serde — use js_sys::Array iteration instead.

NAPI (ruvector-attention-node):
- Add stepInPlace() to the SGD, Adam, and AdamW optimizers for zero-copy in-place parameter mutation via Float32Array's AsMut<[f32]>
- Document that step() returns a NEW array (callers must use the return value)

Note: LoRA B=0 initialization in learning-wasm is correct by design (Hu et al. 2021) — documented in ADR-145, no code change needed.

Co-authored-by: Reuven <cohen@ruv-mac-mini.local>
1 parent 3829c85 · commit 3e67c72

3 files changed: 221 additions & 14 deletions


crates/ruvector-attention-node/src/training.rs

Lines changed: 74 additions & 12 deletions
```diff
@@ -257,14 +257,16 @@ impl SGDOptimizer {
         }
     }

-    /// Perform an optimization step
+    /// Perform an optimization step, returning a **new** `Float32Array`.
     ///
-    /// # Arguments
-    /// * `params` - Parameter array
-    /// * `gradients` - Gradient array
+    /// The input `params` buffer is consumed and a fresh array is returned with
+    /// the updated values. Callers **must** use the return value:
     ///
-    /// # Returns
-    /// Updated parameter array
+    /// ```js
+    /// params = optimizer.step(params, gradients);
+    /// ```
+    ///
+    /// If you want to mutate the JS buffer in-place instead, use `stepInPlace`.
     #[napi]
     pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
         let mut params_vec = params.to_vec();
@@ -273,6 +275,22 @@ impl SGDOptimizer {
         Float32Array::new(params_vec)
     }

+    /// Perform an optimization step **in-place** on the underlying JS buffer.
+    ///
+    /// This mutates the `Float32Array`'s backing `ArrayBuffer` directly, so the
+    /// caller's original typed-array view reflects the updated weights without
+    /// needing to capture a return value:
+    ///
+    /// ```js
+    /// optimizer.stepInPlace(params, gradients); // params is mutated
+    /// ```
+    #[napi]
+    pub fn step_in_place(&mut self, mut params: Float32Array, gradients: Float32Array) {
+        let gradients_slice = gradients.as_ref();
+        let params_slice = params.as_mut();
+        self.inner.step(params_slice, gradients_slice);
+    }
+
     /// Reset optimizer state
     #[napi]
     pub fn reset(&mut self) {
@@ -339,10 +357,16 @@ impl AdamOptimizer {
         }
     }

-    /// Perform an optimization step
+    /// Perform an optimization step, returning a **new** `Float32Array`.
     ///
-    /// # Returns
-    /// Updated parameter array
+    /// The input `params` buffer is consumed and a fresh array is returned with
+    /// the updated values. Callers **must** use the return value:
+    ///
+    /// ```js
+    /// params = optimizer.step(params, gradients);
+    /// ```
+    ///
+    /// If you want to mutate the JS buffer in-place instead, use `stepInPlace`.
     #[napi]
     pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
         let mut params_vec = params.to_vec();
@@ -351,6 +375,22 @@ impl AdamOptimizer {
         Float32Array::new(params_vec)
     }

+    /// Perform an optimization step **in-place** on the underlying JS buffer.
+    ///
+    /// This mutates the `Float32Array`'s backing `ArrayBuffer` directly, so the
+    /// caller's original typed-array view reflects the updated weights without
+    /// needing to capture a return value:
+    ///
+    /// ```js
+    /// optimizer.stepInPlace(params, gradients); // params is mutated
+    /// ```
+    #[napi]
+    pub fn step_in_place(&mut self, mut params: Float32Array, gradients: Float32Array) {
+        let gradients_slice = gradients.as_ref();
+        let params_slice = params.as_mut();
+        self.inner.step(params_slice, gradients_slice);
+    }
+
     /// Reset optimizer state (momentum terms)
     #[napi]
     pub fn reset(&mut self) {
@@ -411,10 +451,16 @@ impl AdamWOptimizer {
         }
     }

-    /// Perform an optimization step
+    /// Perform an optimization step, returning a **new** `Float32Array`.
     ///
-    /// # Returns
-    /// Updated parameter array
+    /// The input `params` buffer is consumed and a fresh array is returned with
+    /// the updated values. Callers **must** use the return value:
+    ///
+    /// ```js
+    /// params = optimizer.step(params, gradients);
+    /// ```
+    ///
+    /// If you want to mutate the JS buffer in-place instead, use `stepInPlace`.
     #[napi]
     pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
         let mut params_vec = params.to_vec();
@@ -423,6 +469,22 @@ impl AdamWOptimizer {
         Float32Array::new(params_vec)
     }

+    /// Perform an optimization step **in-place** on the underlying JS buffer.
+    ///
+    /// This mutates the `Float32Array`'s backing `ArrayBuffer` directly, so the
+    /// caller's original typed-array view reflects the updated weights without
+    /// needing to capture a return value:
+    ///
+    /// ```js
+    /// optimizer.stepInPlace(params, gradients); // params is mutated
+    /// ```
+    #[napi]
+    pub fn step_in_place(&mut self, mut params: Float32Array, gradients: Float32Array) {
+        let gradients_slice = gradients.as_ref();
+        let params_slice = params.as_mut();
+        self.inner.step(params_slice, gradients_slice);
+    }
+
     /// Reset optimizer state
     #[napi]
     pub fn reset(&mut self) {
```
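For reference, a minimal JS-side sketch of the two calling conventions established by this diff (the `require` path and constructor arguments are illustrative assumptions, not taken from the package docs):

```js
// Hypothetical import path; adjust to the actual npm package name.
const { SGDOptimizer } = require('ruvector-attention-node');

const optimizer = new SGDOptimizer(0.01); // constructor args assumed

let params = new Float32Array([0.5, -0.2, 0.1]);
const gradients = new Float32Array([0.1, 0.0, -0.3]);

// Copy-return API: step() consumes the input and returns a NEW array,
// so the result must be captured.
params = optimizer.step(params, gradients);

// Zero-copy API: stepInPlace() mutates the backing ArrayBuffer directly.
optimizer.stepInPlace(params, gradients); // `params` now holds updated weights
```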

crates/ruvector-attention-wasm/src/training.rs

Lines changed: 6 additions & 2 deletions
```diff
@@ -32,9 +32,13 @@ impl WasmInfoNCELoss {
         positive: &[f32],
         negatives: JsValue,
     ) -> Result<f32, JsError> {
-        let negatives_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
+        let array = js_sys::Array::from(&negatives);
+        let mut negatives_vec: Vec<Vec<f32>> = Vec::with_capacity(array.length() as usize);
+        for i in 0..array.length() {
+            let typed_arr = js_sys::Float32Array::new(&array.get(i));
+            negatives_vec.push(typed_arr.to_vec());
+        }
         let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
-
         Ok(self.inner.compute(anchor, positive, &negatives_refs))
     }
 }
```
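A minimal JS-side sketch of the call shape this fix enables (`WasmInfoNCELoss` and `compute` come from the diff above; the module path and constructor arguments are assumptions):

```js
import init, { WasmInfoNCELoss } from 'ruvector-attention-wasm'; // path assumed

await init();
const loss = new WasmInfoNCELoss(0.07); // constructor args (e.g. temperature) assumed

const anchor = new Float32Array([0.1, 0.9]);
const positive = new Float32Array([0.2, 0.8]);
// Negatives can now be passed as Float32Array[]; previously this call
// failed inside serde deserialization.
const negatives = [new Float32Array([0.9, 0.1]), new Float32Array([0.7, 0.3])];

const value = loss.compute(anchor, positive, negatives);
```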
Lines changed: 141 additions & 0 deletions
# ADR-145: WASM/NAPI Training Pipeline Fixes

**Status**: Accepted
**Date**: 2026-04-06
**Authors**: Claude Code (Opus 4.6)
**Supersedes**: None
**Related**: ADR-144 (Monorepo Quality Analysis Strategy)

---

## Context

The WASM and NAPI training pipeline spans three crates:

- `ruvector-learning-wasm` — MicroLoRA adaptation (WASM)
- `ruvector-attention-wasm` — Contrastive loss + optimizers (WASM)
- `ruvector-attention-node` — Contrastive loss + optimizers (NAPI/Node.js)

Three issues were reported that prevented the training pipeline from producing meaningful adaptation:

1. LoRA weights initialize to zero, producing identity transforms
2. `computeContrastiveLoss` has a type mismatch in the WASM binding
3. `optimizerStep` has a Buffer reference issue in the NAPI bridge

---

## Decision

### Issue 1: LoRA Zero Initialization — NOT A BUG

**File**: `crates/ruvector-learning-wasm/src/lora.rs:62-93`

The B matrix is initialized to zeros (line 83) while A is initialized with Kaiming-like scaling (lines 66-80). This produces an identity transform on the first forward pass: `output = input + alpha * (input @ A @ B) = input` since `B = 0`.

**This is correct LoRA design** per Hu et al. (2021). The LoRA paper specifies:

- A is initialized with a random Gaussian
- B is initialized to zero
- The initial delta is therefore zero, so the pre-trained model is preserved at the start of fine-tuning

The `adapt()` method (lines 148-179) updates B via outer-product gradient updates. After one or more `adapt()` calls, the forward pass produces non-trivial outputs. The existing test at line 523 explicitly verifies this: output differs from input after adaptation.

**Action**: No code change. Document in the npm package README that `adapt()` or `adapt_with_reward()` must be called before the LoRA produces non-identity transforms.
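A short JS sketch of the call pattern to document (illustrative only: the class name and `forward()` are assumptions, while `adapt()` and `adapt_with_reward()` are the methods named above):

```js
const lora = new MicroLoRA(/* dims, rank, alpha: constructor assumed */);

const input = new Float32Array([0.3, 0.7, 0.1]);
const before = lora.forward(input); // equals `input`: B = 0 gives the identity transform

lora.adapt(input /*, gradient or reward signal */); // updates B

const after = lora.forward(input); // now differs from `input`
```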

### Issue 2: WASM Contrastive Loss Type Mismatch — REAL BUG

**File**: `crates/ruvector-attention-wasm/src/training.rs:29-39`

```rust
// CURRENT (broken): negatives param is untyped JsValue
pub fn compute(
    &self,
    anchor: &[f32],
    positive: &[f32],
    negatives: JsValue, // ← Problem: JS Float32Array[] doesn't deserialize to Vec<Vec<f32>>
) -> Result<f32, JsError> {
    let negatives_vec: Vec<Vec<f32>> = serde_wasm_bindgen::from_value(negatives)?;
    // ...
}
```

When JS passes `Float32Array[]`, `serde_wasm_bindgen::from_value` fails because each `Float32Array` is a TypedArray backed by an `ArrayBuffer`, not a regular JS Array of numbers. The deserializer sees a TypedArray and cannot convert it to `Vec<f32>`.

The NAPI binding (`ruvector-attention-node/src/training.rs:53-66`) handles this correctly using the native `Vec<Float32Array>` type.

**Fix**: Convert each `Float32Array` element explicitly via `js_sys::Float32Array` before collecting into `Vec<Vec<f32>>`.
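Before this fix, the only JS-side workaround was to copy each typed array into a plain `number[]`, which serde can deserialize. A sketch of that now-unnecessary pattern, reusing the `loss` call shape above:

```js
// Old workaround (no longer needed): serde_wasm_bindgen accepts plain
// number[][] but not Float32Array[], so callers copied element-wise.
const negativesPlain = negatives.map((n) => Array.from(n));
const value = loss.compute(anchor, positive, negativesPlain);
```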
65+
66+
### Issue 3: NAPI Optimizer Step Buffer Reference — DESIGN BUG
67+
68+
**Files**: `crates/ruvector-attention-node/src/training.rs:269,347,419`
69+
70+
```rust
71+
// CURRENT: consumes params, returns new allocation
72+
pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
73+
let mut params_vec = params.to_vec(); // ← copies data from Buffer
74+
let gradients_slice = gradients.as_ref();
75+
self.inner.step(&mut params_vec, gradients_slice);
76+
Float32Array::new(params_vec) // ← allocates new Buffer, original is dropped
77+
}
78+
```
79+
80+
The `step()` method takes `Float32Array` by value, copies to a Vec, mutates the copy, and returns a new `Float32Array` backed by a new Buffer allocation. This means:
81+
- The caller's original Buffer reference is invalidated (consumed by the NAPI bridge)
82+
- Each step allocates and deallocates a Buffer (GC pressure)
83+
- Callers expecting in-place mutation of their typed array see no change
84+
85+
The Rust `Optimizer::step()` trait method operates on `&mut [f32]` (in-place), but the NAPI binding doesn't expose this correctly.
86+
87+
**Fix**: Use `Buffer` or `&mut [f32]` semantics to mutate in-place, or clearly document the copy-return pattern so callers assign the return value.
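In JS terms, the failure mode this fix addresses (the correct patterns match the doc comments added in this commit):

```js
// BROKEN: return value discarded; `params` never reflects the update,
// because step() copies into a new Buffer instead of mutating in-place.
optimizer.step(params, gradients);

// CORRECT (copy-return API): reassign the returned array.
params = optimizer.step(params, gradients);

// CORRECT (zero-copy API added by this commit): mutate in-place.
optimizer.stepInPlace(params, gradients);
```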

---

## Affected Files

### Crate: `ruvector-attention-wasm`

| File | Change | Priority |
|------|--------|----------|
| `src/training.rs:29-39` | Replace `JsValue` negatives param with explicit `Float32Array` array handling via `js_sys` | Critical |

### Crate: `ruvector-attention-node`

| File | Change | Priority |
|------|--------|----------|
| `src/training.rs:269` | `SGDOptimizer::step` — document copy-return or switch to in-place mutation | High |
| `src/training.rs:347` | `AdamOptimizer::step` — same fix | High |
| `src/training.rs:419` | `AdamWOptimizer::step` — same fix | High |

### Crate: `ruvector-learning-wasm`

| File | Change | Priority |
|------|--------|----------|
| `src/lora.rs` | No code change — add documentation clarifying B=0 is by design | Low |

---

## Consequences

### Positive

- **Contrastive loss becomes usable from JS**: `Float32Array[]` inputs now deserialize correctly
- **Optimizer step semantics become clear**: either in-place mutation or a documented copy-return
- **LoRA misconception resolved**: documented that identity-on-init is correct LoRA behavior

### Negative

- **WASM API signature change**: the `compute()` parameter type changes from `JsValue` to explicit typed-array handling — a breaking change for any existing callers
- **NAPI optimizer API may change**: if switching to in-place mutation, callers that rely on the return value need updating

### Risks

| Risk | Likelihood | Impact | Mitigation |
|------|------------|--------|------------|
| WASM API break affects downstream | Low | Medium | The API was broken anyway (it always errored on `Float32Array[]`) |
| In-place mutation causes NAPI safety issues | Medium | Low | Use `Buffer::from_mut` or `Ref<Float32Array>` |

---

## References

- Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models" (2021) — B=0 initialization
- wasm-bindgen TypedArray handling: https://docs.rs/js-sys/latest/js_sys/struct.Float32Array.html
- NAPI-RS Buffer semantics: https://napi.rs/docs/concepts/external
