Skip to content

Commit 9044766

Browse files
committed
[MLIR] add pass creation fns, batch op constructor, MLIR_CAPI_EXPORTED on all definitions
1 parent 9dc699c commit 9044766

3 files changed

Lines changed: 248 additions & 79 deletions

File tree

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
add_mlir_public_c_api_library(MLIRCAPIEnzyme
22
EnzymeMLIR.cpp
33

4+
DEPENDS
5+
MLIRImpulseEnumsIncGen
6+
MLIRImpulseAttributesIncGen
7+
48
LINK_LIBS PUBLIC
59
MLIRIR
10+
MLIRCAPIIR
11+
12+
LINK_LIBS PRIVATE
613
MLIREnzyme
14+
MLIRImpulse
15+
MLIREnzymeTransforms
16+
MLIREnzymeAnalysis
17+
MLIREnzymeAutoDiffInterface
18+
MLIREnzymeImplementations
719
)
Lines changed: 149 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,186 @@
1+
//===- EnzymeMLIR.cpp - C API for Enzyme MLIR dialect ---------------------===//
2+
//
3+
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
110
#include "EnzymeMLIR.h"
211

312
#include "mlir/CAPI/IR.h"
13+
#include "mlir/CAPI/Pass.h"
414
#include "mlir/CAPI/Registration.h"
515

616
#include "Dialect/Dialect.h"
717
#include "Dialect/Impulse/Impulse.h"
818
#include "Dialect/Ops.h"
19+
#include "Implementations/CoreDialectsAutoDiffImplementations.h"
20+
#include "Passes/Passes.h"
21+
#include "llvm/Support/ErrorHandling.h"
922
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme,
1023
mlir::enzyme::EnzymeDialect)
1124

12-
MlirAttribute enzymeActivityAttrGet(MlirContext ctx, EnzymeActivity activity) {
25+
MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void) {
26+
mlir::enzyme::registerDifferentiatePass();
27+
mlir::enzyme::registerExpandImpulsePass();
28+
mlir::enzyme::registerBatchPass();
29+
mlir::enzyme::registerBatchDiffPass();
30+
mlir::enzyme::registerDifferentiateWrapperPass();
31+
mlir::enzyme::registerInlineEnzymeIntoRegionPass();
32+
mlir::enzyme::registerOutlineEnzymeFromRegionPass();
33+
mlir::enzyme::registerPrintActivityAnalysisPass();
34+
mlir::enzyme::registerPrintAliasAnalysisPass();
35+
mlir::enzyme::registerRemoveUnusedEnzymeOpsPass();
36+
}
37+
38+
MLIR_CAPI_EXPORTED void
39+
enzymeRegisterDialectExtensions(MlirDialectRegistry registry) {
40+
mlir::enzyme::registerCoreDialectAutodiffInterfaces(*unwrap(registry));
41+
}
42+
43+
MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void) {
44+
return wrap(mlir::enzyme::createDifferentiatePass().release());
45+
}
46+
47+
MLIR_CAPI_EXPORTED MlirPass
48+
enzymeCreateDifferentiatePassWithOptions(MlirStringRef postpasses,
49+
bool verifyPostPasses) {
50+
mlir::enzyme::DifferentiatePassOptions opts;
51+
opts.postpasses = std::string(postpasses.data, postpasses.length);
52+
opts.verifyPostPasses = verifyPostPasses;
53+
return wrap(mlir::enzyme::createDifferentiatePass(opts).release());
54+
}
55+
56+
MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void) {
57+
return wrap(mlir::enzyme::createEnzymeOpsToMemRefPass().release());
58+
}
59+
60+
MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void) {
61+
return wrap(mlir::enzyme::createBatchPass().release());
62+
}
63+
64+
MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void) {
65+
return wrap(mlir::enzyme::createBatchDiffPass().release());
66+
}
67+
68+
MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void) {
69+
return wrap(mlir::enzyme::createRemoveUnusedEnzymeOpsPass().release());
70+
}
71+
72+
MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx,
73+
uint32_t activity) {
1374
mlir::enzyme::Activity act;
1475
switch (activity) {
15-
case EnzymeActivity_enzyme_active:
76+
case 0:
1677
act = mlir::enzyme::Activity::enzyme_active;
1778
break;
18-
case EnzymeActivity_enzyme_dup:
79+
case 1:
1980
act = mlir::enzyme::Activity::enzyme_dup;
2081
break;
21-
case EnzymeActivity_enzyme_const:
82+
case 2:
2283
act = mlir::enzyme::Activity::enzyme_const;
2384
break;
24-
case EnzymeActivity_enzyme_dupnoneed:
85+
case 3:
2586
act = mlir::enzyme::Activity::enzyme_dupnoneed;
2687
break;
27-
case EnzymeActivity_enzyme_activenoneed:
88+
case 4:
2889
act = mlir::enzyme::Activity::enzyme_activenoneed;
2990
break;
30-
case EnzymeActivity_enzyme_constnoneed:
91+
case 5:
3192
act = mlir::enzyme::Activity::enzyme_constnoneed;
3293
break;
94+
default:
95+
llvm_unreachable("invalid Enzyme activity");
3396
}
3497
return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), act));
3598
}
3699

37100
static mlir::ArrayAttr activityArrayAttr(MlirContext ctx,
38-
MlirAttribute *activity, intptr_t n) {
39-
llvm::SmallVector<mlir::Attribute> attrs;
40-
attrs.reserve(n);
41-
for (intptr_t i = 0; i < n; ++i)
42-
attrs.push_back(unwrap(activity[i]));
43-
return mlir::ArrayAttr::get(unwrap(ctx), attrs);
44-
}
45-
46-
static void collectTypes(MlirType *src, intptr_t n,
47-
llvm::SmallVectorImpl<mlir::Type> &out) {
48-
out.reserve(n);
49-
for (intptr_t i = 0; i < n; ++i)
50-
out.push_back(unwrap(src[i]));
101+
const MlirAttribute *activity,
102+
intptr_t n) {
103+
return mlir::ArrayAttr::get(
104+
unwrap(ctx),
105+
llvm::ArrayRef<mlir::Attribute>(
106+
reinterpret_cast<const mlir::Attribute *>(activity), n));
51107
}
52108

53-
static void collectValues(MlirValue *src, intptr_t n,
54-
llvm::SmallVectorImpl<mlir::Value> &out) {
55-
out.reserve(n);
56-
for (intptr_t i = 0; i < n; ++i)
57-
out.push_back(unwrap(src[i]));
109+
MLIR_CAPI_EXPORTED MlirOperation
110+
enzymeAutoDiffOpCreate(MlirContext ctx, MlirStringRef fn,
111+
const MlirType *resultTypes, intptr_t nResults,
112+
const MlirValue *inputs, intptr_t nInputs,
113+
const MlirAttribute *activity, intptr_t nActivity,
114+
const MlirAttribute *retActivity, intptr_t nRetActivity,
115+
int64_t width, bool strongZero, MlirLocation loc) {
116+
auto op = mlir::OpBuilder(unwrap(ctx)).create<mlir::enzyme::AutoDiffOp>(
117+
unwrap(loc),
118+
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
119+
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
120+
llvm::StringRef(fn.data, fn.length),
121+
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
122+
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
123+
activityArrayAttr(ctx, activity, nActivity),
124+
activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width,
125+
strongZero);
126+
return wrap(op.getOperation());
58127
}
59128

60-
MlirOperation
61-
enzymeAutoDiffOpCreate(MlirContext ctx, MlirStringRef fn, MlirType *resultTypes,
62-
intptr_t nResults, MlirValue *inputs, intptr_t nInputs,
63-
MlirAttribute *activity, intptr_t nActivity,
64-
MlirAttribute *retActivity, intptr_t nRetActivity,
65-
int64_t width, bool strongZero, MlirLocation loc) {
66-
auto *mlirCtx = unwrap(ctx);
67-
llvm::SmallVector<mlir::Type> results;
68-
collectTypes(resultTypes, nResults, results);
69-
llvm::SmallVector<mlir::Value> operands;
70-
collectValues(inputs, nInputs, operands);
71-
auto op = mlir::OpBuilder(mlirCtx).create<mlir::enzyme::AutoDiffOp>(
72-
unwrap(loc), mlir::TypeRange(results),
73-
llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands),
129+
MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate(
130+
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
131+
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
132+
const MlirAttribute *activity, intptr_t nActivity,
133+
const MlirAttribute *retActivity, intptr_t nRetActivity,
134+
int64_t width, bool strongZero, MlirLocation loc) {
135+
auto op = mlir::OpBuilder(unwrap(ctx)).create<mlir::enzyme::ForwardDiffOp>(
136+
unwrap(loc),
137+
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
138+
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
139+
llvm::StringRef(fn.data, fn.length),
140+
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
141+
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
74142
activityArrayAttr(ctx, activity, nActivity),
75143
activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width,
76144
strongZero);
77145
return wrap(op.getOperation());
78146
}
79147

80-
MlirOperation enzymeForwardDiffOpCreate(
81-
MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults,
82-
MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity,
83-
intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity,
148+
MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate(
149+
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
150+
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
151+
const MlirAttribute *activity, intptr_t nActivity,
152+
const MlirAttribute *retActivity, intptr_t nRetActivity,
84153
int64_t width, bool strongZero, MlirLocation loc) {
85-
auto *mlirCtx = unwrap(ctx);
86-
llvm::SmallVector<mlir::Type> results;
87-
collectTypes(resultTypes, nResults, results);
88-
llvm::SmallVector<mlir::Value> operands;
89-
collectValues(inputs, nInputs, operands);
90-
auto op = mlir::OpBuilder(mlirCtx).create<mlir::enzyme::ForwardDiffOp>(
91-
unwrap(loc), mlir::TypeRange(results),
92-
llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands),
154+
auto op = mlir::OpBuilder(unwrap(ctx)).create<mlir::enzyme::JacobianOp>(
155+
unwrap(loc),
156+
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
157+
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
158+
llvm::StringRef(fn.data, fn.length),
159+
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
160+
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
93161
activityArrayAttr(ctx, activity, nActivity),
94162
activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width,
95163
strongZero);
96164
return wrap(op.getOperation());
97165
}
98166

99-
MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
100-
EnzymeRngDistribution dist) {
167+
MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate(
168+
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
169+
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
170+
const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc) {
171+
auto op = mlir::OpBuilder(unwrap(ctx)).create<mlir::enzyme::BatchOp>(
172+
unwrap(loc),
173+
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
174+
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
175+
llvm::StringRef(fn.data, fn.length),
176+
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
177+
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
178+
llvm::ArrayRef<int64_t>(batchShape, nBatchShape));
179+
return wrap(op.getOperation());
180+
}
181+
182+
MLIR_CAPI_EXPORTED MlirAttribute
183+
enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) {
101184
mlir::impulse::RngDistribution rngDist;
102185
switch (dist) {
103186
case EnzymeRngDistribution_Uniform:
@@ -113,9 +196,10 @@ MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
113196
return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist));
114197
}
115198

116-
MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
117-
bool hasLowerBound, double lowerBound,
118-
bool hasUpperBound, double upperBound) {
199+
MLIR_CAPI_EXPORTED MlirAttribute
200+
enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
201+
bool hasLowerBound, double lowerBound,
202+
bool hasUpperBound, double upperBound) {
119203
auto *mlirCtx = unwrap(ctx);
120204

121205
mlir::impulse::SupportKind supportKind;
@@ -154,8 +238,9 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
154238
upperAttr));
155239
}
156240

157-
MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
158-
bool adaptStepSize, bool adaptMassMatrix) {
241+
MLIR_CAPI_EXPORTED MlirAttribute
242+
enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
243+
bool adaptStepSize, bool adaptMassMatrix) {
159244
auto *mlirCtx = unwrap(ctx);
160245
auto trajectoryLengthAttr =
161246
mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength);
@@ -164,10 +249,10 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
164249
mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix));
165250
}
166251

167-
MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
168-
bool hasMaxDeltaEnergy,
169-
double maxDeltaEnergy, bool adaptStepSize,
170-
bool adaptMassMatrix) {
252+
MLIR_CAPI_EXPORTED MlirAttribute
253+
enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
254+
bool hasMaxDeltaEnergy, double maxDeltaEnergy,
255+
bool adaptStepSize, bool adaptMassMatrix) {
171256
auto *mlirCtx = unwrap(ctx);
172257

173258
mlir::FloatAttr maxDeltaEnergyAttr;
@@ -180,6 +265,7 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
180265
adaptMassMatrix));
181266
}
182267

183-
MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) {
268+
MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx,
269+
uint64_t ptr) {
184270
return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr));
185271
}

0 commit comments

Comments
 (0)