Skip to content

Commit 7fdec04

Browse files
Improve accuracy for models using shuffle, unshuffle, cat ops (#19159)
Summary: Replace the Qualcomm concat observer path with an explicit same-domain-or-requantize model for `aten.cat`. Preserve shared qparams for `pixel_shuffle` and `pixel_unshuffle`, extend `split_with_sizes_copy` coverage, and add regressions for mismatched `cat` branches plus value-preserving ops that must use `SharedQuantizationSpec`. Differential Revision: D102626539
1 parent ddd8ac6 commit 7fdec04

10 files changed

Lines changed: 573 additions & 193 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from .annotate_avg_pool1d import AnnotateAvgPool1D
8+
from .annotate_concat_requant import AnnotateConcatRequant
89
from .annotate_quant_attrs import AnnotateQuantAttrs
910
from .annotate_stack import AnnotateStack
1011
from .annotate_unbind import AnnotateUnbind
@@ -60,6 +61,7 @@
6061

6162
__all__ = [
6263
AnnotateAvgPool1D,
64+
AnnotateConcatRequant,
6365
AnnotateQuantAttrs,
6466
AnnotateStack,
6567
AnnotateUnbind,
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Dict
8+
9+
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
11+
from executorch.backends.qualcomm.utils.constants import (
12+
QCOM_DTYPE,
13+
QCOM_ENCODING,
14+
QCOM_QUANT_MAX,
15+
QCOM_QUANT_MIN,
16+
QCOM_REQUANTIZE,
17+
QCOM_SCALE,
18+
QCOM_ZERO_POINT,
19+
)
20+
from executorch.exir.dialects._ops import ops as exir_ops
21+
from executorch.exir.pass_base import ExportPass, PassResult
22+
23+
from .utils import get_quant_attrs
24+
25+
26+
# Edge-dialect concat overloads; a quantize node fed by one of these targets
# triggers the per-input requantization check in AnnotateConcatRequant.
EDGE_CAT_OPS = {exir_ops.edge.aten.cat.default, exir_ops.edge.aten.concat.default}
30+
31+
32+
class AnnotateConcatRequant(ExportPass):
    """
    Record explicit requantization needs for concat inputs whose concrete
    post-calibration qparams do not match concat's output domain.

    For every quantize node whose operand is a concat, each concat input that
    follows the pattern ``source -> q -> dq -> cat`` is inspected. When the
    quant attributes of that ``q`` differ from the quant attributes of the
    concat's output ``q``, a copy of the output attributes (carrying the
    source's encoding) is stored in
    ``source.meta[QCOM_REQUANTIZE][<cat node name>]``.

    Args:
        edge_program: Exported program passed to ``get_quant_attrs`` when
            reading the attributes of quantize nodes.
        skip_advanced_requant: When True, only a dtype mismatch counts as a
            reason to requantize; otherwise scale, zero point, quant range
            and dtype are all compared.
    """

    # Attributes compared when advanced requantization is enabled. Kept at
    # class level so the sequence is not rebuilt on every comparison.
    _REQUANT_KEYS = (
        QCOM_SCALE,
        QCOM_ZERO_POINT,
        QCOM_QUANT_MIN,
        QCOM_QUANT_MAX,
        QCOM_DTYPE,
    )

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
        skip_advanced_requant: bool = False,
    ):
        super().__init__()
        self.edge_program = edge_program
        self.skip_advanced_requant = skip_advanced_requant

    def _is_requant_needed(
        self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]
    ) -> bool:
        """Return whether ``src_attrs`` and ``dst_attrs`` describe different domains.

        With ``skip_advanced_requant`` set, only ``QCOM_DTYPE`` is compared;
        otherwise any difference among ``_REQUANT_KEYS`` counts.
        """
        if self.skip_advanced_requant:
            return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE]

        return any(src_attrs[key] != dst_attrs[key] for key in self._REQUANT_KEYS)

    def _annotate_concat_input_requant(self, quant_node: torch.fx.Node) -> None:
        """Annotate the producers of mismatched concat inputs.

        Args:
            quant_node: A quantize node; nothing happens unless its first
                argument is a node whose target is in ``EDGE_CAT_OPS``.
        """
        cat_node = quant_node.args[0]
        # Non-Node operands have no ``target``; treat them as "not a concat".
        if (
            not isinstance(cat_node, torch.fx.Node)
            or cat_node.target not in EDGE_CAT_OPS
        ):
            return

        output_q_attrs = get_quant_attrs(self.edge_program, quant_node)
        for input_node in cat_node.args[0]:
            # Only inputs arriving through a dequantize node are candidates.
            if (
                not isinstance(input_node, torch.fx.Node)
                or input_node.target not in dq_ops
            ):
                continue

            source_q_node = input_node.args[0]
            if (
                not isinstance(source_q_node, torch.fx.Node)
                or source_q_node.target not in q_ops
            ):
                continue

            source_q_attrs = get_quant_attrs(self.edge_program, source_q_node)
            if not self._is_requant_needed(source_q_attrs, output_q_attrs):
                continue

            source_node = source_q_node.args[0]
            if not isinstance(source_node, torch.fx.Node):
                continue

            # Target domain is the concat output's; the encoding entry is
            # taken from the source quantize node.
            requant_attrs = output_q_attrs.copy()
            requant_attrs[QCOM_ENCODING] = source_q_attrs[QCOM_ENCODING]
            # Keyed by concat name so one producer can carry separate
            # requant records for different consumers.
            source_node.meta.setdefault(QCOM_REQUANTIZE, {})[
                cat_node.name
            ] = requant_attrs

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Visit every quantize node fed by a concat and annotate its inputs.

        Only node metadata is written; the graph structure is left untouched.
        The result is always reported as modified.
        """
        for node in graph_module.graph.nodes:
            if (
                node.target in q_ops
                and isinstance(node.args[0], torch.fx.Node)
                and node.args[0].target in EDGE_CAT_OPS
            ):
                self._annotate_concat_input_requant(node)
        return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
7979

8080
return last_dq_nodes
8181

82+
def _is_requant_needed(self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]):
    """Decide whether two sets of quant attributes call for requantization.

    In ``skip_advanced_requant`` mode only a dtype difference is considered.
    In full mode a difference in scale, zero point, quant range or dtype is
    enough, which aims to improve accuracy; users can fall back to the
    dtype-only mode if that causes regressions.
    """
    if not self.skip_advanced_requant:
        compared_keys = (
            QCOM_SCALE,
            QCOM_ZERO_POINT,
            QCOM_QUANT_MIN,
            QCOM_QUANT_MAX,
            QCOM_DTYPE,
        )
        # Stop at the first differing attribute, in the listed order.
        for key in compared_keys:
            if src_attrs[key] != dst_attrs[key]:
                return True
        return False

    return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE]
96+
8297
def _annotate_requant(self, n):
8398
# Record requant attributes:
8499
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
@@ -96,28 +111,7 @@ def _annotate_requant(self, n):
96111
# that has multiple outputs that requires quant attributes.
97112

98113
# Determine if requantization is needed based on configuration and attribute mismatch.
99-
is_requant_needed = False
100-
if self.skip_advanced_requant:
101-
# In skip_advanced_requant mode, only consider requant if dtypes differ.
102-
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
103-
is_requant_needed = True
104-
else:
105-
# In full requant mode, consider requant if any key attribute differs.
106-
# This aims to improve accuracy by adjusting scale, zero_point, etc.
107-
# Users can disable this if it causes regressions.
108-
if any(
109-
q_attrs[attr] != dq_attrs[attr]
110-
for attr in [
111-
QCOM_SCALE,
112-
QCOM_ZERO_POINT,
113-
QCOM_QUANT_MIN,
114-
QCOM_QUANT_MAX,
115-
QCOM_DTYPE,
116-
]
117-
):
118-
is_requant_needed = True
119-
120-
if is_requant_needed:
114+
if self._is_requant_needed(q_attrs, dq_attrs):
121115
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
122116
user_node = list(dq_node.users)[0]
123117
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from executorch.backends.qualcomm._passes import (
1212
AnnotateAvgPool1D,
13+
AnnotateConcatRequant,
1314
AnnotateQuantAttrs,
1415
AnnotateStack,
1516
AnnotateUnbind,
@@ -99,6 +100,7 @@ def get_capture_program_passes():
99100
default_passes_and_setting = [
100101
(AnnotateAvgPool1D, True),
101102
(AnnotateQuantAttrs, True),
103+
(AnnotateConcatRequant, True),
102104
(AnnotateStack, True),
103105
(AnnotateUnbind, True),
104106
(ConvertBmmToMatmul, False),

backends/qualcomm/_passes/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def get_passes_dependency_for_capture_program():
6161
"""
6262
from executorch.backends.qualcomm._passes import (
6363
AnnotateAvgPool1D,
64+
AnnotateConcatRequant,
6465
AnnotateQuantAttrs,
6566
AnnotateStack,
6667
AnnotateUnbind,
@@ -89,6 +90,7 @@ def get_passes_dependency_for_capture_program():
8990

9091
return {
9192
AnnotateAvgPool1D: [RemoveRedundancy],
93+
AnnotateConcatRequant: [AnnotateQuantAttrs],
9294
AnnotateQuantAttrs: [
9395
ConvertBmmToMatmul,
9496
RecomposePixelUnshuffle,
@@ -108,9 +110,15 @@ def get_passes_dependency_for_capture_program():
108110
DecomposeTrunc: [RemoveRedundancy],
109111
ExpandBroadcastTensorShape: [FoldQDQ],
110112
FixedLinearKeepDim: [FoldQDQ],
111-
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
113+
FoldQDQ: [
114+
AnnotateConcatRequant,
115+
AnnotateQuantAttrs,
116+
AnnotateStack,
117+
AnnotateUnbind,
118+
],
112119
I64toI32: [RemoveRedundancy],
113120
LayoutTransform: [
121+
AnnotateConcatRequant,
114122
AnnotateQuantAttrs,
115123
ExpandBroadcastTensorShape,
116124
FixedLinearKeepDim,

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
1414
import torch
15-
1615
from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
1716
ConcatObserver,
1817
)
@@ -235,31 +234,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
235234
return
236235

237236
input_qspec_map, input_nodes = {}, node.args[0]
238-
for input in input_nodes:
239-
input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
237+
for input_node in input_nodes:
238+
assert isinstance(input_node, Node)
239+
input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None)
240240
qspec = getattr(input_qspec, "output_qspec", None)
241-
# keep shared qspec here for propagation the data range
242-
# without introducing extra requantizations
241+
# Preserve shared upstream qspecs, but derive concat's output domain
242+
# from the merged output range to avoid clipping wider branches.
243243
if isinstance(qspec, SharedQuantizationSpec):
244-
input_qspec_map[input] = SharedQuantizationSpec(input)
244+
input_qspec_map[input_node] = SharedQuantizationSpec(input_node)
245245
else:
246-
input_qspec_map[input] = quantization_config.input_activation
246+
input_qspec_map[input_node] = quantization_config.input_activation
247247

248248
output_qspec = QuantizationSpec(
249249
dtype=quantization_config.output_activation.dtype,
250250
qscheme=quantization_config.output_activation.qscheme,
251251
quant_max=quantization_config.output_activation.quant_max,
252252
quant_min=quantization_config.output_activation.quant_min,
253253
observer_or_fake_quant_ctr=ConcatObserver.with_args(
254-
# we need to know the concat node in order to hack all the input observers' data range
255-
# since deep copy of fake tensor (node.meta["val"]) is inhibited
256-
# we could only ship grap & node name and perform postprocess inside observer currently
257-
**{
258-
"node_name": node.name,
259-
"graph": node.graph,
260-
}
254+
node_name=node.name,
255+
graph=node.graph,
261256
),
262257
)
258+
263259
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
264260
input_qspec_map=input_qspec_map,
265261
output_qspec=output_qspec,
@@ -295,6 +291,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
295291
@register_annotator(
296292
[
297293
torch.ops.aten.split_with_sizes.default,
294+
torch.ops.aten.split_with_sizes_copy.default,
298295
torch.ops.aten.split.Tensor,
299296
torch.ops.aten.chunk.default,
300297
],
@@ -1203,14 +1200,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
12031200
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
12041201
)
12051202
class PixelShuffle(GeneralOpDef):
1206-
pass
1203+
@staticmethod
1204+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
1205+
annotate_in_out_obs_sharing_op(node, quantization_config)
1206+
if not _is_annotated([node]):
1207+
annotate_single_in_share_out(node, quantization_config)
12071208

12081209

12091210
@register_annotator(
12101211
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
12111212
)
12121213
class PixelUnshuffle(GeneralOpDef):
1213-
pass
1214+
@staticmethod
1215+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
1216+
annotate_in_out_obs_sharing_op(node, quantization_config)
1217+
if not _is_annotated([node]):
1218+
annotate_single_in_share_out(node, quantization_config)
12141219

12151220

12161221
@register_annotator(

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
1313
import torch
14-
1514
from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
1615
ConcatObserver,
1716
)
@@ -181,31 +180,26 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
181180
return
182181

183182
input_qspec_map, input_nodes = {}, node.args[0]
184-
for input in input_nodes:
185-
input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
183+
for input_node in input_nodes:
184+
assert isinstance(input_node, Node)
185+
input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None)
186186
qspec = getattr(input_qspec, "output_qspec", None)
187-
# keep shared qspec here for propagation the data range
188-
# without introducing extra requantizations
189187
if isinstance(qspec, SharedQuantizationSpec):
190-
input_qspec_map[input] = SharedQuantizationSpec(input)
188+
input_qspec_map[input_node] = SharedQuantizationSpec(input_node)
191189
else:
192-
input_qspec_map[input] = quantization_config.input_activation
190+
input_qspec_map[input_node] = quantization_config.input_activation
193191

194192
output_qspec = QuantizationSpec(
195193
dtype=quantization_config.output_activation.dtype,
196194
qscheme=quantization_config.output_activation.qscheme,
197195
quant_max=quantization_config.output_activation.quant_max,
198196
quant_min=quantization_config.output_activation.quant_min,
199197
observer_or_fake_quant_ctr=ConcatObserver.with_args(
200-
# we need to know the concat node in order to hack all the input observers' data range
201-
# since deep copy of fake tensor (node.meta["val"]) is inhibited
202-
# we could only ship grap & node name and perform postprocess inside observer currently
203-
**{
204-
"node_name": node.name,
205-
"graph": node.graph,
206-
}
198+
node_name=node.name,
199+
graph=node.graph,
207200
),
208201
)
202+
209203
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
210204
input_qspec_map=input_qspec_map,
211205
output_qspec=output_qspec,
@@ -223,6 +217,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
223217
@register_annotator(
224218
[
225219
torch.ops.aten.split_with_sizes.default,
220+
torch.ops.aten.split_with_sizes_copy.default,
226221
torch.ops.aten.split.Tensor,
227222
torch.ops.aten.chunk.default,
228223
],
@@ -705,14 +700,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
705700
[torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
706701
)
707702
class PixelShuffle(GeneralOpDef):
708-
pass
703+
@staticmethod
704+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
705+
annotate_in_out_obs_sharing_op(node, quantization_config)
706+
if not _is_annotated([node]):
707+
annotate_single_in_share_out(node, quantization_config)
709708

710709

711710
@register_annotator(
712711
[torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
713712
)
714713
class PixelUnshuffle(GeneralOpDef):
715-
pass
714+
@staticmethod
715+
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
716+
annotate_in_out_obs_sharing_op(node, quantization_config)
717+
if not _is_annotated([node]):
718+
annotate_single_in_share_out(node, quantization_config)
716719

717720

718721
@register_annotator(

0 commit comments

Comments
 (0)