Skip to content

Commit 1f42ce5

Browse files
committed
fix reverse ad for memref.alloca
1 parent 89837d0 commit 1f42ce5

3 files changed

Lines changed: 55 additions & 16 deletions

File tree

enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,12 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset,
223223
}
224224
assert(!isConstantValue(val));
225225

226+
bool isMutable = false;
227+
if (auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType()))
228+
isMutable = iface.isMutable();
229+
226230
if (mode == DerivativeMode::ForwardMode ||
227-
mode == DerivativeMode::ForwardModeSplit) {
231+
mode == DerivativeMode::ForwardModeSplit || isMutable) {
228232
setInvertedPointer(val, toset);
229233
}
230234
/*
@@ -240,11 +244,13 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset,
240244

241245
void mlir::enzyme::MGradientUtils::setInvertedPointer(Value val, Value toset) {
242246
assert(getShadowType(val.getType()) == toset.getType());
243-
auto found = invertedPointers.lookupOrNull(val);
244-
assert(found != nullptr);
245-
auto placeholder = found.getDefiningOp<enzyme::PlaceholderOp>();
246-
placeholder.replaceAllUsesWith(toset);
247-
erase(placeholder);
247+
248+
if (auto found = invertedPointers.lookupOrNull(val)) {
249+
if (auto placeholder = found.getDefiningOp<enzyme::PlaceholderOp>()) {
250+
placeholder.replaceAllUsesWith(toset);
251+
erase(placeholder);
252+
}
253+
}
248254
invertedPointers.map(val, toset);
249255
}
250256

enzyme/test/MLIR/ReverseMode/alloca.mlir

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %eopt --split-input-file --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s
1+
// RUN: %eopt --enzyme %s | FileCheck %s
22

33
func.func @foo_flat(%x : f64) -> f64 {
44
%buf = memref.alloca() : memref<f64>
@@ -18,12 +18,45 @@ func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 {
1818
// CHECK-LABEL: func.func private @diffefoo_flat(
1919
// CHECK-SAME: %[[X:[^,]+]]: f64,
2020
// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 {
21-
// A shadow memref.alloca must be created and zero-initialized so the
22-
// reverse-mode adjoint can accumulate gradients into it.
23-
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
24-
// CHECK-DAG: %[[DBUF:.*]] = memref.alloca() : memref<f64>
25-
// CHECK: memref.store %[[ZERO]], %[[DBUF]][] : memref<f64>
26-
// No leftover placeholders should survive the differentiation.
21+
22+
// CHECK: %[[GX:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
23+
// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
24+
// CHECK: %[[CS:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
25+
// CHECK: %[[CL:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
26+
// CHECK: %[[GY:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
27+
// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
28+
29+
// CHECK: %[[DBUF:.+]] = memref.alloca() : memref<f64>
30+
// CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64
31+
// CHECK: memref.store %[[ZERO_INIT]], %[[DBUF]][] : memref<f64>
32+
33+
// CHECK: %[[BUF:.+]] = memref.alloca() : memref<f64>
34+
// CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
35+
// CHECK: memref.store %[[X]], %[[BUF]][] : memref<f64>
36+
// CHECK: "enzyme.push"(%[[CL]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
37+
// CHECK: memref.load %[[BUF]][] : memref<f64>
38+
// CHECK: cf.br ^bb1
39+
// CHECK: ^bb1:
40+
41+
// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
42+
// CHECK: arith.addf %{{.+}}, %[[DOUT]] : f64
43+
// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
44+
45+
// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
46+
// CHECK: %[[POPL:.+]] = "enzyme.pop"(%[[CL]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
47+
// CHECK: memref.load %[[POPL]][] : memref<f64>
48+
// CHECK: arith.addf
49+
// CHECK: memref.store %{{.+}}, %[[POPL]][] : memref<f64>
50+
51+
// CHECK: %[[POPS:.+]] = "enzyme.pop"(%[[CS]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
52+
// CHECK: memref.load %[[POPS]][] : memref<f64>
53+
// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
54+
// CHECK: arith.addf
55+
// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
56+
// CHECK: %[[ZERO_CLEAR:.+]] = arith.constant 0.000000e+00 : f64
57+
// CHECK: memref.store %[[ZERO_CLEAR]], %[[POPS]][] : memref<f64>
58+
59+
// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
60+
// CHECK: return %{{.+}} : f64
61+
2762
// CHECK-NOT: enzyme.placeholder
28-
// The function ultimately returns the gradient w.r.t. x (= dout).
29-
// CHECK: return %{{.*}} : f64

enzyme/test/MLIR/ReverseMode/alloca_scope2.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func.func @foo(%x : f64) -> f64{
1414
func.func @dfoo(%x: f64, %dout : f64) -> f64 {
1515
%dx = enzyme.autodiff @foo(%x, %dout) {
1616
activity = [#enzyme<activity enzyme_active>],
17-
ret_activity = [#enzyme<acitivity enzyme_activenoneed>]
17+
ret_activity = [#enzyme<activity enzyme_activenoneed>]
1818
} : (f64, f64) -> (f64)
1919
return %dx : f64
2020
}

0 commit comments

Comments
 (0)