Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion packages/transformers/src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,43 @@ async function ensureWasmLoaded() {
return wasmLoadPromise;
}

/**
 * Lowercase names of execution providers that have already failed to
 * initialize, kept so subsequent sessions don't pay the same cost twice
 * (e.g. on Linux x64 boxes without a CUDA runtime, `cuda` is on the
 * auto-fallback chain but throws on first use).
 *
 * `createInferenceSession` adds a name here when session creation fails with
 * a provider-load error, and filters known-failed providers out of later
 * `executionProviders` lists. Nothing in the code shown here removes entries.
 *
 * @type {Set<string>}
 */
const failedExecutionProviders = new Set();

/**
 * Pulls the provider name out of an ONNX Runtime error whose text contains
 * `OrtSessionOptionsAppendExecutionProvider_<Name>`.
 *
 * The error may be an `Error`, any object with a `message`, or a plain value;
 * it is stringified before matching. The captured name is lowercased so it
 * lines up with the names used in `DEVICE_TO_EXECUTION_PROVIDER_MAPPING`.
 *
 * @param {unknown} error The value thrown during session creation.
 * @returns {string|null} The lowercase provider name, or `null` when the text
 * does not look like a provider-load failure.
 */
function detectFailedExecutionProvider(error) {
    // @ts-ignore — `error` is `unknown`; `.message` is only read when present.
    const text = String(error?.message ?? error ?? '');
    const found = /OrtSessionOptionsAppendExecutionProvider_(\w+)/.exec(text);
    return found === null ? null : found[1].toLowerCase();
}

/**
 * Normalizes an execution provider entry to its lowercase name.
 *
 * Entries are either a bare string (`'cuda'`) or an options object carrying
 * a `name` property (`{ name: 'webnn', deviceType: 'cpu' }`).
 *
 * @param {ONNXExecutionProviders} provider The provider entry to inspect.
 * @returns {string} The provider's name, lowercased.
 */
function executionProviderName(provider) {
    if (typeof provider === 'string') {
        return provider.toLowerCase();
    }
    return provider.name.toLowerCase();
}

/**
* Create an ONNX inference session.
* @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
Expand All @@ -286,13 +323,60 @@ async function ensureWasmLoaded() {
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
    await ensureWasmLoaded();
    const logSeverityLevel = getOnnxLogSeverityLevel(env.logLevel ?? LogLevel.WARNING);

    // Work on a local binding so the caller-visible parameter is never
    // reassigned; the caller's options object itself is only ever copied.
    let options = session_options;

    // Drop providers we already know don't work on this host (e.g. `cuda` on a
    // box without the CUDA runtime). Only applies when `executionProviders` is
    // a list — leave any caller-supplied non-list value untouched. Optional
    // chaining keeps a missing `session_options` working, which the spread in
    // `load` below already tolerates.
    const requested = options?.executionProviders;
    if (Array.isArray(requested) && failedExecutionProviders.size > 0) {
        const filtered = requested.filter((p) => !failedExecutionProviders.has(executionProviderName(p)));
        // If every requested provider is known-bad, keep the original list
        // rather than handing ONNX Runtime an empty one.
        if (filtered.length > 0 && filtered.length !== requested.length) {
            options = { ...options, executionProviders: filtered };
        }
    }

    const load = () =>
        InferenceSession.create(buffer_or_path, {
            // Set default log severity level, but allow overriding through session options
            logSeverityLevel,
            ...options,
        });

    /** @type {import('onnxruntime-common').InferenceSession} */
    let session;
    try {
        // NOTE(review): in web environments a rejected `load` is stored back
        // into `webInitChain`, so a later `webInitChain.then(load)` would
        // reject without calling `load` — including the retry below.
        // `webInitChain` is declared outside this block; confirm whether the
        // chain is meant to recover after a failure.
        session = await (apis.IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load());
    } catch (error) {
        // If session creation fails because an execution provider's shared
        // library didn't load (e.g. CUDA on a host without the runtime), drop
        // it from the list and retry. ONNX Runtime fails the whole session
        // when ANY provider fails to register, so without this the entire
        // auto-fallback chain is wasted: a Linux x64 box with no CUDA install
        // hits the cuda branch first and never reaches webgpu/cpu.
        const failed = detectFailedExecutionProvider(error);
        const providers = options?.executionProviders;
        if (
            failed &&
            Array.isArray(providers) &&
            // Only retry when something else is left to try and the failed
            // provider is actually in the list; each retry removes one
            // provider, so the recursion is bounded by the list length.
            providers.length > 1 &&
            providers.some((p) => executionProviderName(p) === failed)
        ) {
            failedExecutionProviders.add(failed);
            const remaining = providers.filter((p) => executionProviderName(p) !== failed);
            logger.warn(
                `Execution provider '${failed}' is unavailable on this host (${/** @type {Error} */ (error).message}). ` +
                    `Falling back to: [${remaining.map(executionProviderName).join(', ')}].`,
            );
            return createInferenceSession(
                buffer_or_path,
                { ...options, executionProviders: remaining },
                session_config,
            );
        }
        // Not a recognizable provider-load failure (or nothing left to fall
        // back to): surface the original error unchanged.
        throw error;
    }

    session.config = session_config;
    return session;
}
Expand Down