55
66import torch
77from deepspeed .accelerator .abstract_accelerator import DeepSpeedAccelerator
8- import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9- import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
108import functools
11-
129import importlib
1310import inspect
1411
12+ try :
13+ import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14+ oneccl_imported_p = True
15+ except ImportError as e :
16+ oneccl_imported_p = False
17+
18+ try :
19+ import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20+ ipex_imported_p = True
21+ except ImportError as e :
22+ ipex_imported_p = False
1523
1624class XPU_Accelerator (DeepSpeedAccelerator ):
1725
1826 def __init__ (self ):
1927 self ._name = 'xpu'
2028 self ._communication_backend_name = 'ccl'
29+ if oneccl_imported_p :
30+ self ._communication_backend_name = 'ccl'
31+ else :
32+ # changed to xccl if not using torch-CCL on XPU device
33+ self ._communication_backend_name = 'xccl'
2134 self ._compile_backend = "inductor"
2235 self .aligned_tensors = []
2336 self .class_dict = None
@@ -26,11 +39,14 @@ def is_synchronized_device(self):
2639 return False
2740
2841 def use_host_timers (self ):
29- # WA XPU event will be consolidated in 2.6
30- if ipex .__version__ < '2.6' :
31- return True
32- else :
42+ if not ipex_imported_p :
3343 return self .is_synchronized_device ()
44+ else :
45+ # WA XPU event will be consolidated in 2.6
46+ if ipex .__version__ < '2.6' :
47+ return True
48+ else :
49+ return self .is_synchronized_device ()
3450
3551 def resolves_data_dependency (self ):
3652 return self .is_synchronized_device ()
@@ -290,10 +306,13 @@ def get_op_builder(self, class_name):
290306 return self .class_dict ['NotImplementedBuilder' ]
291307
292308 def build_extension (self ):
293- try :
294- from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
295- except ImportError :
296- from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
309+ if not ipex_imported_p :
310+ try :
311+ from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
312+ except ImportError :
313+ from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
314+ else :
315+ from torch .utils .cpp_extension import DpcppBuildExtension
297316 return DpcppBuildExtension
298317
299318 def export_envs (self ):
0 commit comments