|
12 | 12 |
|
13 | 13 | import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants |
14 | 14 | import torch |
15 | | - |
16 | 15 | from executorch.backends.qualcomm.quantizer.observers.concat_observer import ( |
17 | 16 | ConcatObserver, |
18 | 17 | ) |
@@ -235,31 +234,28 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: |
235 | 234 | return |
236 | 235 |
|
237 | 236 | input_qspec_map, input_nodes = {}, node.args[0] |
238 | | - for input in input_nodes: |
239 | | - input_qspec = input.meta.get(Q_ANNOTATION_KEY, None) |
| 237 | + for input_node in input_nodes: |
| 238 | + assert isinstance(input_node, Node) |
| 239 | + input_qspec = input_node.meta.get(Q_ANNOTATION_KEY, None) |
240 | 240 | qspec = getattr(input_qspec, "output_qspec", None) |
241 | | - # keep shared qspec here for propagation the data range |
242 | | - # without introducing extra requantizations |
| 241 | + # Preserve shared upstream qspecs, but derive concat's output domain |
| 242 | + # from the merged output range to avoid clipping wider branches. |
243 | 243 | if isinstance(qspec, SharedQuantizationSpec): |
244 | | - input_qspec_map[input] = SharedQuantizationSpec(input) |
| 244 | + input_qspec_map[input_node] = SharedQuantizationSpec(input_node) |
245 | 245 | else: |
246 | | - input_qspec_map[input] = quantization_config.input_activation |
| 246 | + input_qspec_map[input_node] = quantization_config.input_activation |
247 | 247 |
|
248 | 248 | output_qspec = QuantizationSpec( |
249 | 249 | dtype=quantization_config.output_activation.dtype, |
250 | 250 | qscheme=quantization_config.output_activation.qscheme, |
251 | 251 | quant_max=quantization_config.output_activation.quant_max, |
252 | 252 | quant_min=quantization_config.output_activation.quant_min, |
253 | 253 | observer_or_fake_quant_ctr=ConcatObserver.with_args( |
254 | | - # we need to know the concat node in order to hack all the input observers' data range |
255 | | - # since deep copy of fake tensor (node.meta["val"]) is inhibited |
256 | | - # we could only ship grap & node name and perform postprocess inside observer currently |
257 | | - **{ |
258 | | - "node_name": node.name, |
259 | | - "graph": node.graph, |
260 | | - } |
| 254 | + node_name=node.name, |
| 255 | + graph=node.graph, |
261 | 256 | ), |
262 | 257 | ) |
| 258 | + |
263 | 259 | node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( |
264 | 260 | input_qspec_map=input_qspec_map, |
265 | 261 | output_qspec=output_qspec, |
@@ -295,6 +291,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: |
295 | 291 | @register_annotator( |
296 | 292 | [ |
297 | 293 | torch.ops.aten.split_with_sizes.default, |
| 294 | + torch.ops.aten.split_with_sizes_copy.default, |
298 | 295 | torch.ops.aten.split.Tensor, |
299 | 296 | torch.ops.aten.chunk.default, |
300 | 297 | ], |
@@ -1203,14 +1200,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None: |
1203 | 1200 | [torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name |
1204 | 1201 | ) |
1205 | 1202 | class PixelShuffle(GeneralOpDef): |
1206 | | - pass |
| 1203 | + @staticmethod |
| 1204 | + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: |
| 1205 | + annotate_in_out_obs_sharing_op(node, quantization_config) |
| 1206 | + if not _is_annotated([node]): |
| 1207 | + annotate_single_in_share_out(node, quantization_config) |
1207 | 1208 |
|
1208 | 1209 |
|
1209 | 1210 | @register_annotator( |
1210 | 1211 | [torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name |
1211 | 1212 | ) |
1212 | 1213 | class PixelUnshuffle(GeneralOpDef): |
1213 | | - pass |
| 1214 | + @staticmethod |
| 1215 | + def annotate(node: Node, quantization_config: QuantizationConfig) -> None: |
| 1216 | + annotate_in_out_obs_sharing_op(node, quantization_config) |
| 1217 | + if not _is_annotated([node]): |
| 1218 | + annotate_single_in_share_out(node, quantization_config) |
1214 | 1219 |
|
1215 | 1220 |
|
1216 | 1221 | @register_annotator( |
|
0 commit comments