1- // RUN: %eopt --split-input-file -- enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s
1+ // RUN: %eopt --enzyme %s | FileCheck %s
22
33func.func @foo_flat (%x : f64 ) -> f64 {
44 %buf = memref.alloca () : memref <f64 >
@@ -18,12 +18,45 @@ func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 {
1818// CHECK-LABEL: func.func private @diffefoo_flat(
1919// CHECK-SAME: %[[X:[^,]+]]: f64,
2020// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 {
21- // A shadow memref.alloca must be created and zero-initialized so the
22- // reverse-mode adjoint can accumulate gradients into it.
23- // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64
24- // CHECK-DAG: %[[DBUF:.*]] = memref.alloca() : memref<f64>
25- // CHECK: memref.store %[[ZERO]], %[[DBUF]][] : memref<f64>
26- // No leftover placeholders should survive the differentiation.
21+
22+ // CHECK: %[[GX:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
23+ // CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
24+ // CHECK: %[[CS:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
25+ // CHECK: %[[CL:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
26+ // CHECK: %[[GY:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
27+ // CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
28+
29+ // CHECK: %[[DBUF:.+]] = memref.alloca() : memref<f64>
30+ // CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64
31+ // CHECK: memref.store %[[ZERO_INIT]], %[[DBUF]][] : memref<f64>
32+
33+ // CHECK: %[[BUF:.+]] = memref.alloca() : memref<f64>
34+ // CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
35+ // CHECK: memref.store %[[X]], %[[BUF]][] : memref<f64>
36+ // CHECK: "enzyme.push"(%[[CL]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
37+ // CHECK: memref.load %[[BUF]][] : memref<f64>
38+ // CHECK: cf.br ^bb1
39+ // CHECK: ^bb1:
40+
41+ // CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
42+ // CHECK: arith.addf %{{.+}}, %[[DOUT]] : f64
43+ // CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
44+
45+ // CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
46+ // CHECK: %[[POPL:.+]] = "enzyme.pop"(%[[CL]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
47+ // CHECK: memref.load %[[POPL]][] : memref<f64>
48+ // CHECK: arith.addf
49+ // CHECK: memref.store %{{.+}}, %[[POPL]][] : memref<f64>
50+
51+ // CHECK: %[[POPS:.+]] = "enzyme.pop"(%[[CS]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
52+ // CHECK: memref.load %[[POPS]][] : memref<f64>
53+ // CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
54+ // CHECK: arith.addf
55+ // CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
56+ // CHECK: %[[ZERO_CLEAR:.+]] = arith.constant 0.000000e+00 : f64
57+ // CHECK: memref.store %[[ZERO_CLEAR]], %[[POPS]][] : memref<f64>
58+
59+ // CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
60+ // CHECK: return %{{.+}} : f64
61+
2762// CHECK-NOT: enzyme.placeholder
28- // The function ultimately returns the gradient w.r.t. x (= dout).
29- // CHECK: return %{{.*}} : f64
0 commit comments