1+ // ===- EnzymeMLIR.cpp - C API for Enzyme MLIR dialect ---------------------===//
2+ //
3+ // Part of the Enzyme Project, under the Apache License v2.0 with LLVM
4+ // Exceptions.
5+ // See https://llvm.org/LICENSE.txt for license information.
6+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+ //
8+ // ===----------------------------------------------------------------------===//
9+
110#include " EnzymeMLIR.h"
211
312#include " mlir/CAPI/IR.h"
13+ #include " mlir/CAPI/Pass.h"
414#include " mlir/CAPI/Registration.h"
515
616#include " Dialect/Dialect.h"
717#include " Dialect/Impulse/Impulse.h"
818#include " Dialect/Ops.h"
19+ #include " Implementations/CoreDialectsAutoDiffImplementations.h"
20+ #include " Passes/Passes.h"
21+ #include " llvm/Support/ErrorHandling.h"
922MLIR_DEFINE_CAPI_DIALECT_REGISTRATION (Enzyme, enzyme,
1023 mlir::enzyme::EnzymeDialect)
1124
12- MlirAttribute enzymeActivityAttrGet (MlirContext ctx, EnzymeActivity activity) {
25+ MLIR_CAPI_EXPORTED void enzymeRegisterPasses (void ) {
26+ mlir::enzyme::registerDifferentiatePass ();
27+ mlir::enzyme::registerExpandImpulsePass ();
28+ mlir::enzyme::registerBatchPass ();
29+ mlir::enzyme::registerBatchDiffPass ();
30+ mlir::enzyme::registerDifferentiateWrapperPass ();
31+ mlir::enzyme::registerInlineEnzymeIntoRegionPass ();
32+ mlir::enzyme::registerOutlineEnzymeFromRegionPass ();
33+ mlir::enzyme::registerPrintActivityAnalysisPass ();
34+ mlir::enzyme::registerPrintAliasAnalysisPass ();
35+ mlir::enzyme::registerRemoveUnusedEnzymeOpsPass ();
36+ }
37+
38+ MLIR_CAPI_EXPORTED void
39+ enzymeRegisterDialectExtensions (MlirDialectRegistry registry) {
40+ mlir::enzyme::registerCoreDialectAutodiffInterfaces (*unwrap (registry));
41+ }
42+
43+ MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass (void ) {
44+ return wrap (mlir::enzyme::createDifferentiatePass ().release ());
45+ }
46+
47+ MLIR_CAPI_EXPORTED MlirPass
48+ enzymeCreateDifferentiatePassWithOptions (MlirStringRef postpasses,
49+ bool verifyPostPasses) {
50+ mlir::enzyme::DifferentiatePassOptions opts;
51+ opts.postpasses = std::string (postpasses.data , postpasses.length );
52+ opts.verifyPostPasses = verifyPostPasses;
53+ return wrap (mlir::enzyme::createDifferentiatePass (opts).release ());
54+ }
55+
56+ MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass (void ) {
57+ return wrap (mlir::enzyme::createEnzymeOpsToMemRefPass ().release ());
58+ }
59+
60+ MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass (void ) {
61+ return wrap (mlir::enzyme::createBatchPass ().release ());
62+ }
63+
64+ MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass (void ) {
65+ return wrap (mlir::enzyme::createBatchDiffPass ().release ());
66+ }
67+
68+ MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass (void ) {
69+ return wrap (mlir::enzyme::createRemoveUnusedEnzymeOpsPass ().release ());
70+ }
71+
72+ MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet (MlirContext ctx,
73+ uint32_t activity) {
1374 mlir::enzyme::Activity act;
1475 switch (activity) {
15- case EnzymeActivity_enzyme_active :
76+ case 0 :
1677 act = mlir::enzyme::Activity::enzyme_active;
1778 break ;
18- case EnzymeActivity_enzyme_dup :
79+ case 1 :
1980 act = mlir::enzyme::Activity::enzyme_dup;
2081 break ;
21- case EnzymeActivity_enzyme_const :
82+ case 2 :
2283 act = mlir::enzyme::Activity::enzyme_const;
2384 break ;
24- case EnzymeActivity_enzyme_dupnoneed :
85+ case 3 :
2586 act = mlir::enzyme::Activity::enzyme_dupnoneed;
2687 break ;
27- case EnzymeActivity_enzyme_activenoneed :
88+ case 4 :
2889 act = mlir::enzyme::Activity::enzyme_activenoneed;
2990 break ;
30- case EnzymeActivity_enzyme_constnoneed :
91+ case 5 :
3192 act = mlir::enzyme::Activity::enzyme_constnoneed;
3293 break ;
94+ default :
95+ llvm_unreachable (" invalid Enzyme activity" );
3396 }
3497 return wrap (mlir::enzyme::ActivityAttr::get (unwrap (ctx), act));
3598}
3699
37100static mlir::ArrayAttr activityArrayAttr (MlirContext ctx,
38- MlirAttribute *activity, intptr_t n) {
39- llvm::SmallVector<mlir::Attribute> attrs;
40- attrs.reserve (n);
41- for (intptr_t i = 0 ; i < n; ++i)
42- attrs.push_back (unwrap (activity[i]));
43- return mlir::ArrayAttr::get (unwrap (ctx), attrs);
44- }
45-
46- static void collectTypes (MlirType *src, intptr_t n,
47- llvm::SmallVectorImpl<mlir::Type> &out) {
48- out.reserve (n);
49- for (intptr_t i = 0 ; i < n; ++i)
50- out.push_back (unwrap (src[i]));
101+ const MlirAttribute *activity,
102+ intptr_t n) {
103+ return mlir::ArrayAttr::get (
104+ unwrap (ctx),
105+ llvm::ArrayRef<mlir::Attribute>(
106+ reinterpret_cast <const mlir::Attribute *>(activity), n));
51107}
52108
53- static void collectValues (MlirValue *src, intptr_t n,
54- llvm::SmallVectorImpl<mlir::Value> &out) {
55- out.reserve (n);
56- for (intptr_t i = 0 ; i < n; ++i)
57- out.push_back (unwrap (src[i]));
109+ MLIR_CAPI_EXPORTED MlirOperation
110+ enzymeAutoDiffOpCreate (MlirContext ctx, MlirStringRef fn,
111+ const MlirType *resultTypes, intptr_t nResults,
112+ const MlirValue *inputs, intptr_t nInputs,
113+ const MlirAttribute *activity, intptr_t nActivity,
114+ const MlirAttribute *retActivity, intptr_t nRetActivity,
115+ int64_t width, bool strongZero, MlirLocation loc) {
116+ auto op = mlir::OpBuilder (unwrap (ctx)).create <mlir::enzyme::AutoDiffOp>(
117+ unwrap (loc),
118+ mlir::TypeRange (llvm::ArrayRef<mlir::Type>(
119+ reinterpret_cast <const mlir::Type *>(resultTypes), nResults)),
120+ llvm::StringRef (fn.data , fn.length ),
121+ mlir::ValueRange (llvm::ArrayRef<mlir::Value>(
122+ reinterpret_cast <const mlir::Value *>(inputs), nInputs)),
123+ activityArrayAttr (ctx, activity, nActivity),
124+ activityArrayAttr (ctx, retActivity, nRetActivity), (uint64_t )width,
125+ strongZero);
126+ return wrap (op.getOperation ());
58127}
59128
60- MlirOperation
61- enzymeAutoDiffOpCreate (MlirContext ctx, MlirStringRef fn, MlirType *resultTypes,
62- intptr_t nResults, MlirValue *inputs, intptr_t nInputs,
63- MlirAttribute *activity, intptr_t nActivity,
64- MlirAttribute *retActivity, intptr_t nRetActivity,
65- int64_t width, bool strongZero, MlirLocation loc) {
66- auto *mlirCtx = unwrap (ctx);
67- llvm::SmallVector<mlir::Type> results;
68- collectTypes (resultTypes, nResults, results);
69- llvm::SmallVector<mlir::Value> operands;
70- collectValues (inputs, nInputs, operands);
71- auto op = mlir::OpBuilder (mlirCtx).create <mlir::enzyme::AutoDiffOp>(
72- unwrap (loc), mlir::TypeRange (results),
73- llvm::StringRef (fn.data , fn.length ), mlir::ValueRange (operands),
129+ MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate (
130+ MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
131+ intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
132+ const MlirAttribute *activity, intptr_t nActivity,
133+ const MlirAttribute *retActivity, intptr_t nRetActivity,
134+ int64_t width, bool strongZero, MlirLocation loc) {
135+ auto op = mlir::OpBuilder (unwrap (ctx)).create <mlir::enzyme::ForwardDiffOp>(
136+ unwrap (loc),
137+ mlir::TypeRange (llvm::ArrayRef<mlir::Type>(
138+ reinterpret_cast <const mlir::Type *>(resultTypes), nResults)),
139+ llvm::StringRef (fn.data , fn.length ),
140+ mlir::ValueRange (llvm::ArrayRef<mlir::Value>(
141+ reinterpret_cast <const mlir::Value *>(inputs), nInputs)),
74142 activityArrayAttr (ctx, activity, nActivity),
75143 activityArrayAttr (ctx, retActivity, nRetActivity), (uint64_t )width,
76144 strongZero);
77145 return wrap (op.getOperation ());
78146}
79147
80- MlirOperation enzymeForwardDiffOpCreate (
81- MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults,
82- MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity,
83- intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity,
148+ MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate (
149+ MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
150+ intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
151+ const MlirAttribute *activity, intptr_t nActivity,
152+ const MlirAttribute *retActivity, intptr_t nRetActivity,
84153 int64_t width, bool strongZero, MlirLocation loc) {
85- auto *mlirCtx = unwrap (ctx);
86- llvm::SmallVector<mlir::Type> results;
87- collectTypes (resultTypes, nResults, results);
88- llvm::SmallVector<mlir::Value> operands;
89- collectValues (inputs, nInputs, operands);
90- auto op = mlir::OpBuilder (mlirCtx).create <mlir::enzyme::ForwardDiffOp>(
91- unwrap (loc), mlir::TypeRange (results),
92- llvm::StringRef (fn.data , fn.length ), mlir::ValueRange (operands),
154+ auto op = mlir::OpBuilder (unwrap (ctx)).create <mlir::enzyme::JacobianOp>(
155+ unwrap (loc),
156+ mlir::TypeRange (llvm::ArrayRef<mlir::Type>(
157+ reinterpret_cast <const mlir::Type *>(resultTypes), nResults)),
158+ llvm::StringRef (fn.data , fn.length ),
159+ mlir::ValueRange (llvm::ArrayRef<mlir::Value>(
160+ reinterpret_cast <const mlir::Value *>(inputs), nInputs)),
93161 activityArrayAttr (ctx, activity, nActivity),
94162 activityArrayAttr (ctx, retActivity, nRetActivity), (uint64_t )width,
95163 strongZero);
96164 return wrap (op.getOperation ());
97165}
98166
99- MlirAttribute enzymeRngDistributionAttrGet (MlirContext ctx,
100- EnzymeRngDistribution dist) {
167+ MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate (
168+ MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
169+ intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
170+ const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc) {
171+ auto op = mlir::OpBuilder (unwrap (ctx)).create <mlir::enzyme::BatchOp>(
172+ unwrap (loc),
173+ mlir::TypeRange (llvm::ArrayRef<mlir::Type>(
174+ reinterpret_cast <const mlir::Type *>(resultTypes), nResults)),
175+ llvm::StringRef (fn.data , fn.length ),
176+ mlir::ValueRange (llvm::ArrayRef<mlir::Value>(
177+ reinterpret_cast <const mlir::Value *>(inputs), nInputs)),
178+ llvm::ArrayRef<int64_t >(batchShape, nBatchShape));
179+ return wrap (op.getOperation ());
180+ }
181+
182+ MLIR_CAPI_EXPORTED MlirAttribute
183+ enzymeRngDistributionAttrGet (MlirContext ctx, EnzymeRngDistribution dist) {
101184 mlir::impulse::RngDistribution rngDist;
102185 switch (dist) {
103186 case EnzymeRngDistribution_Uniform:
@@ -113,9 +196,10 @@ MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
113196 return wrap (mlir::impulse::RngDistributionAttr::get (unwrap (ctx), rngDist));
114197}
115198
116- MlirAttribute enzymeSupportAttrGet (MlirContext ctx, EnzymeSupportKind kind,
117- bool hasLowerBound, double lowerBound,
118- bool hasUpperBound, double upperBound) {
199+ MLIR_CAPI_EXPORTED MlirAttribute
200+ enzymeSupportAttrGet (MlirContext ctx, EnzymeSupportKind kind,
201+ bool hasLowerBound, double lowerBound,
202+ bool hasUpperBound, double upperBound) {
119203 auto *mlirCtx = unwrap (ctx);
120204
121205 mlir::impulse::SupportKind supportKind;
@@ -154,8 +238,9 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
154238 upperAttr));
155239}
156240
157- MlirAttribute enzymeHMCConfigAttrGet (MlirContext ctx, double trajectoryLength,
158- bool adaptStepSize, bool adaptMassMatrix) {
241+ MLIR_CAPI_EXPORTED MlirAttribute
242+ enzymeHMCConfigAttrGet (MlirContext ctx, double trajectoryLength,
243+ bool adaptStepSize, bool adaptMassMatrix) {
159244 auto *mlirCtx = unwrap (ctx);
160245 auto trajectoryLengthAttr =
161246 mlir::FloatAttr::get (mlir::Float64Type::get (mlirCtx), trajectoryLength);
@@ -164,10 +249,10 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
164249 mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix));
165250}
166251
167- MlirAttribute enzymeNUTSConfigAttrGet (MlirContext ctx, int64_t maxTreeDepth,
168- bool hasMaxDeltaEnergy ,
169- double maxDeltaEnergy, bool adaptStepSize ,
170- bool adaptMassMatrix) {
252+ MLIR_CAPI_EXPORTED MlirAttribute
253+ enzymeNUTSConfigAttrGet (MlirContext ctx, int64_t maxTreeDepth ,
254+ bool hasMaxDeltaEnergy, double maxDeltaEnergy,
255+ bool adaptStepSize, bool adaptMassMatrix) {
171256 auto *mlirCtx = unwrap (ctx);
172257
173258 mlir::FloatAttr maxDeltaEnergyAttr;
@@ -180,6 +265,7 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
180265 adaptMassMatrix));
181266}
182267
183- MlirAttribute enzymeSymbolAttrGet (MlirContext ctx, uint64_t ptr) {
268+ MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet (MlirContext ctx,
269+ uint64_t ptr) {
184270 return wrap (mlir::impulse::SymbolAttr::get (unwrap (ctx), ptr));
185271}
0 commit comments