Skip to content

Commit 12c9308

Browse files
[python] Bind cudaq.measure_handle, rewire AST bridge, add tests
Python-side counterpart to PR 3b's C++ frontend rewire, per spec cudaq-spec/proposals/measure_handle.bs (Python API; Operational Semantics; IR Representation). Bindings and host stubs - python/runtime/mlir/py_register_dialects.cpp: bind `!cc.measure_handle` so the AST bridge can construct, check, and emit the type from Python. - python/cudaq/kernel_types.py: `cudaq.measure_handle` kernel-type stub (mirrors `qubit` / `qvector` / `qview`); host-scope construction raises a device-only RuntimeError. - python/cudaq/__init__.py: re-export `cudaq.measure_handle` and add a host stub for `cudaq.to_bools(handles)` raising the same device-only error -- kernel-side calls are intercepted by the AST bridge. - python/cudaq/kernel/utils.py: `containsMeasureHandle(ty)`, mirroring `cudaq::cc::containsMeasureHandle` from PR 3b (`lib/Optimizer/Dialect/CC/CCTypes.cpp`); used by the bridge's boundary check. AST bridge (`PyASTBridge` in `python/cudaq/kernel/ast_bridge.py`) - mz / mx / my emit `quake.{mz,mx,my}` producing `!cc.measure_handle` (scalar) or `!cc.stdvec<!cc.measure_handle>` (vector). Discriminate insertion is deferred to coercion sites. - `__discriminateIfMeasureHandle` inserts `quake.discriminate` at every spec-listed coercion site: arithmetic-to-bool (if / while / not / and / or), `changeOperandToType` (i1 + i8 + stdvec<i1>), Compare (==/!=), explicit `bool(...)`, IfExp test, Assert test; surfaces the `discriminating an unbound measure_handle` diagnostic for the default-constructed pattern. - `cudaq.measure_handle()` -> `cc.UndefOp(!cc.measure_handle)` so the unbound-handle diagnostic has a recognizable source. - `cudaq.to_bools(handles)` -> vectorized `quake.discriminate` on `!cc.stdvec<!cc.measure_handle>`; legacy `cudaq.to_integer(mz(qvec))` is rejected with a targeted diagnostic matching PR 3b's C++ rejection. - Boundary check at entry-point creation: walk the kernel signature with `containsMeasureHandle` and reject any parameter / return position transitively containing a `measure_handle`. Diagnostic matches PR 3b's spec-canonical wording. Tests - python/tests/kernel/test_measure_handle.py: new spec-coverage suite (host-scope rejections, scalar/vector emission shape, every bool-coercion site, `to_bools` lowering, `to_integer(to_bools(...))` composition, unbound-handle diagnostic, boundary diagnostic). - python/tests/mlir/{bug_1775,bug_1777,bug_1875,call_qpu}.py: type-rename and structural CHECK refresh for the new bridge IR shape. - python/tests/kernel/test_{assignments,kernel_features,run_kernel, to_integer,kernel_shift_operators}.py: pytest migrations off the implicit `mz`-returns-bool assumption. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
1 parent e11d97d commit 12c9308

15 files changed

Lines changed: 725 additions & 76 deletions

python/cudaq/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _isinstance(other, _cls=cls, _isinst=py_isinstance):
204204
parallel = cudaq_runtime.parallel
205205

206206
# Primitive Types (stubs; used only in kernels, parsed to MLIR)
207-
from .kernel_types import qubit, qvector, qview
207+
from .kernel_types import measure_handle, qubit, qvector, qview
208208

209209
Pauli = cudaq_runtime.Pauli
210210
Kernel = PyKernel
@@ -319,6 +319,17 @@ def amplitudes(array_data):
319319
return numpy.array(array_data, dtype=complex())
320320

321321

322+
def to_bools(handles):
323+
"""Bulk-discriminate a ``list[cudaq.measure_handle]`` into a
324+
``list[bool]``. Device-only: this Python symbol exists so kernel
325+
code can call ``cudaq.to_bools(...)``; the AST bridge intercepts
326+
the call and lowers it to a vector form ``quake.discriminate`` on
327+
``!cc.stdvec<!cc.measure_handle>``. Host-side invocation raises a
328+
``RuntimeError``.
329+
"""
330+
raise RuntimeError("device-only; usable only inside @cudaq.kernel")
331+
332+
322333
def __clearKernelRegistries():
323334
global globalRegisteredOperations
324335
globalRegisteredOperations.clear()

python/cudaq/kernel/ast_bridge.py

Lines changed: 198 additions & 23 deletions
Large diffs are not rendered by default.

python/cudaq/kernel/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from cudaq.mlir.ir import (ComplexType, F32Type, F64Type, IntegerType, Context,
2525
Module)
2626
from cudaq.mlir._mlir_libs._quakeDialects import register_all_dialects
27-
from cudaq.kernel_types import qubit, qvector, qview
27+
from cudaq.kernel_types import measure_handle, qubit, qvector, qview
2828

2929
State = cudaq_runtime.State
3030
pauli_word = cudaq_runtime.pauli_word
@@ -41,6 +41,23 @@
4141
globalRegisteredTypes = cudaq_runtime.DataClassRegistry
4242

4343

44+
def containsMeasureHandle(ty):
45+
"""Return True iff ``ty`` is ``!cc.measure_handle`` or transitively
46+
contains one.
47+
"""
48+
if cc.MeasureHandleType.isinstance(ty):
49+
return True
50+
if cc.PointerType.isinstance(ty):
51+
return containsMeasureHandle(cc.PointerType.getElementType(ty))
52+
if cc.ArrayType.isinstance(ty):
53+
return containsMeasureHandle(cc.ArrayType.getElementType(ty))
54+
if cc.StdvecType.isinstance(ty):
55+
return containsMeasureHandle(cc.StdvecType.getElementType(ty))
56+
if cc.StructType.isinstance(ty):
57+
return any(containsMeasureHandle(t) for t in cc.StructType.getTypes(ty))
58+
return False
59+
60+
4461
def getMLIRContext():
4562
"""
4663
This code creates an MLIRContext singleton for this python process. We do
@@ -383,6 +400,8 @@ def emitFatalErrorOverride(msg):
383400
return quake.RefType.get()
384401
if annotation.attr == 'pauli_word':
385402
return cc.CharspanType.get()
403+
if annotation.attr == 'measure_handle':
404+
return cc.MeasureHandleType.get()
386405

387406
if annotation.value.id in ['numpy', 'np']:
388407
if annotation.attr in ['array', 'ndarray']:
@@ -675,6 +694,8 @@ def mlirTypeFromPyType(argType, ctx, **kwargs):
675694
return quake.RefType.get(ctx)
676695
if argType == pauli_word:
677696
return cc.CharspanType.get(ctx)
697+
if argType == measure_handle:
698+
return cc.MeasureHandleType.get(ctx)
678699

679700
if 'argInstance' in kwargs:
680701
argInstance = kwargs['argInstance']
@@ -754,6 +775,9 @@ def mlirTypeToPyType(argType):
754775
if cc.CharspanType.isinstance(argType):
755776
return pauli_word
756777

778+
if cc.MeasureHandleType.isinstance(argType):
779+
return measure_handle
780+
757781
if cc.StdvecType.isinstance(argType):
758782
eleTy = cc.StdvecType.getElementType(argType)
759783
if cc.CharspanType.isinstance(eleTy):

python/cudaq/kernel_types.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,29 @@ def __new__(cls, *args, **kwargs):
3535
raise KernelTypeError(cls)
3636

3737

38+
class measure_handle(KernelType):
39+
"""
40+
A handle to a measurement event recorded inside a CUDA-Q kernel.
41+
42+
Returned by ``mz`` / ``mx`` / ``my`` inside an ``@cudaq.kernel`` body
43+
(scalar form on a single qubit; vector form on a ``qvector`` /
44+
``qview``). The classical outcome is read by coercing the handle to
45+
``bool`` in any Python ``bool`` context, and the AST bridge inserts a
46+
``quake.discriminate`` at the coercion site.
47+
``cudaq.to_bools(handles)`` is the bulk counterpart on a
48+
``list[measure_handle]``.
49+
50+
Instantiating ``cudaq.measure_handle()`` at host scope raises
51+
``RuntimeError`` (it is device-only).
52+
"""
53+
54+
def __new__(cls, *args, **kwargs):
55+
raise RuntimeError("device-only; usable only inside @cudaq.kernel")
56+
57+
def __init__(self) -> None:
58+
...
59+
60+
3861
class qubit(KernelType):
3962
"""
4063
The qubit is the primary unit of information in a quantum computer.

python/runtime/mlir/py_register_dialects.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ static void registerCCDialectAndTypes(nanobind::module_ &m) {
200200
},
201201
nanobind::arg("cls"), nanobind::arg("context") = nanobind::none());
202202

203+
mlir_type_subclass(ccMod, "MeasureHandleType",
204+
[](MlirType type) {
205+
return mlir::isa<cudaq::cc::MeasureHandleType>(
206+
unwrap(type));
207+
})
208+
.def_classmethod(
209+
"get",
210+
[](nanobind::object cls, MlirContext context) {
211+
return wrap(cudaq::cc::MeasureHandleType::get(unwrap(context)));
212+
},
213+
nanobind::arg("cls"), nanobind::arg("context") = nanobind::none());
214+
203215
mlir_type_subclass(
204216
ccMod, "StateType",
205217
[](MlirType type) { return mlir::isa<quake::StateType>(unwrap(type)); })

python/tests/kernel/test_assignments.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,25 +1150,23 @@ def test1() -> list[bool]:
11501150
assert 'variable defined in parent scope cannot be modified' in str(e.value)
11511151
assert '(offending source -> c = qs[1])' in str(e.value)
11521152

1153-
# TODO: The reason we cannot currently support this is
1154-
# because we store measurement results as values in the
1155-
# symbol table. This should be changed and supported when
1156-
# we do the change to properly distinguish measurement
1157-
# types from booleans.
1158-
with pytest.raises(RuntimeError) as e:
1159-
1160-
@cudaq.kernel
1161-
def test2() -> bool:
1162-
qs = cudaq.qvector(2)
1163-
res = mz(qs[0])
1164-
if True:
1165-
x(qs[1])
1166-
res = mz(qs[1])
1167-
return res
1153+
# Reassigning a `measure_handle`-typed variable across scopes is
1154+
# supported now that `mz` returns `cudaq.measure_handle` instead of
1155+
# `bool`: the symbol-table slot has handle type, the inner-scope
1156+
# store binds a fresh handle, and the bool-coercion at `return res`
1157+
# discriminates exactly once. Previously this case was disallowed
1158+
# because measurement results were stored as raw `i1` values in the
1159+
# symbol table.
1160+
@cudaq.kernel
1161+
def test2() -> bool:
1162+
qs = cudaq.qvector(2)
1163+
res = mz(qs[0])
1164+
if True:
1165+
x(qs[1])
1166+
res = mz(qs[1])
1167+
return res
11681168

1169-
test2()
1170-
assert 'variable defined in parent scope cannot be modified' in str(e.value)
1171-
assert '(offending source -> res = mz(qs[1]))' in str(e.value)
1169+
test2()
11721170

11731171

11741172
def test_var_scopes():
@@ -1391,7 +1389,8 @@ def fct():
13911389
x(q)
13921390

13931391
fct()
1394-
return i, mz(q)
1392+
# FIXME: aggregate-element typing does not currently auto-discriminate, so coerce explicitly.
1393+
return i, bool(mz(q))
13951394

13961395
out = cudaq.run(test1, True, False, shots_count=10)
13971396
assert all(res == (True, False) for res in out)

python/tests/kernel/test_kernel_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2727,7 +2727,7 @@ def test_mid_circuit_measurements():
27272727

27282728
@cudaq.kernel
27292729
def callee(register: cudaq.qview) -> list[bool]:
2730-
result = [0, 0, 0, 0, 0, 0, 0, 0]
2730+
result = [False, False, False, False, False, False, False, False]
27312731
for i in range(4):
27322732
j = i * 2
27332733
if i % 2 == 0:

python/tests/kernel/test_kernel_shift_operators.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_run_with_integer_left_shift_operator():
7676
@cudaq.kernel
7777
def kernel(n: int) -> int:
7878
q = cudaq.qvector(n)
79-
m = mz(q)
79+
m = cudaq.to_bools(mz(q))
8080
r = 0
8181
for i in range(n):
8282
r = r & (m[i] << i)
@@ -96,7 +96,7 @@ def test_run_with_non_integer_left_shift_operator():
9696
@cudaq.kernel
9797
def kernel(n: int) -> int:
9898
q = cudaq.qvector(n)
99-
m = mz(q)
99+
m = cudaq.to_bools(mz(q))
100100
r = 0
101101
for i in range(n):
102102
r = r & (m[i] << 1.0)
@@ -113,7 +113,7 @@ def test_run_with_integer_right_shift_operator():
113113
@cudaq.kernel
114114
def kernel(n: int) -> int:
115115
q = cudaq.qvector(n)
116-
m = mz(q)
116+
m = cudaq.to_bools(mz(q))
117117
r = 0
118118
for i in range(n):
119119
r = r & (m[i] >> i)
@@ -131,7 +131,7 @@ def test_run_with_integer_bitwise_or_operator():
131131
@cudaq.kernel
132132
def kernel(n: int) -> int:
133133
q = cudaq.qvector(n)
134-
m = mz(q)
134+
m = cudaq.to_bools(mz(q))
135135
r = 0
136136
for i in range(n):
137137
r = r | (m[i] >> i)
@@ -149,7 +149,7 @@ def test_run_with_integer_bitwise_xor_operator():
149149
@cudaq.kernel
150150
def kernel(n: int) -> int:
151151
q = cudaq.qvector(n)
152-
m = mz(q)
152+
m = cudaq.to_bools(mz(q))
153153
r = 0
154154
for i in range(n):
155155
r = r ^ (m[i] >> i)

0 commit comments

Comments
 (0)