Array API Standard Support: signal#
This page explains some caveats of the signal module and provides (currently
incomplete) tables of its CPU, GPU, and JIT support.
Caveats#
JAX and CuPy provide alternative
implementations for some signal functions. When such a function is called, a
decorator decides which implementation to use by inspecting the xp parameter.
Hence, there can be, especially during CI testing, discrepancies in behavior between the default NumPy-based implementation and the JAX and CuPy backends. Skipping the incompatible backends in unit tests, as described in the Adding tests section, is the currently recommended workaround.
The functions are decorated by the code in file
scipy/signal/_support_alternative_backends.py:
1import functools
2import types
3from scipy._lib._array_api import (
4 is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
5)
6
7from ._signal_api import * # noqa: F403
8from . import _signal_api
9from . import _delegators
10__all__ = _signal_api.__all__
11
12
# Name of the SciPy submodule handled here. It is passed to ``delegate_xp`` and
# used there to find the namesake module in the CuPy / JAX SciPy namespaces.
MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts
# (functions listed here are never delegated to CuPy)
CUPY_BLACKLIST = [
    'abcd_normalize', 'bessel', 'besselap', 'envelope', 'get_window', 'lfilter_zi',
    'sosfilt_zi', 'remez',
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
# Maps the scipy function name to the name to look up in cupyx.
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}
31
32
def delegate_xp(delegator, module_name):
    """Build a decorator that dispatches calls to a CuPy or JAX namesake.

    Parameters
    ----------
    delegator : callable
        Called with the same ``*args, **kwds`` as the decorated function; its
        return value is used as the array namespace ``xp`` for that call.
    module_name : str
        Name of the submodule to search for the namesake function:
        ``cupyx.scipy.<module_name>`` for CuPy, and the attribute of that
        name on ``scipy_namespace_for(xp)`` for JAX.

    Returns
    -------
    callable
        A decorator. The function it returns forwards each call to

        - the CuPy namesake, if ``xp`` is CuPy and the function is not in
          ``CUPY_BLACKLIST`` (the name is first mapped through
          ``CUPY_RENAMES``);
        - the JAX namesake, if ``xp`` is JAX and the function is listed in
          ``JAX_SIGNAL_FUNCS``;
        - the original function otherwise.

    Notes
    -----
    A ``TypeError`` raised by ``delegator`` is re-raised, except for
    ``tf2ss``, where NumPy is used as the namespace instead.
    """
    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            try:
                xp = delegator(*args, **kwds)
            except TypeError:
                # object arrays
                if func.__name__ == "tf2ss":
                    import numpy as np
                    xp = np
                else:
                    raise

            # try delegating to a cupyx/jax namesake
            if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
                func_name = CUPY_RENAMES.get(func.__name__, func.__name__)

                # https://github.com/cupy/cupy/issues/8336
                import importlib
                cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
                cupyx_func = getattr(cupyx_module, func_name)
                # the ``xp`` keyword is not forwarded to the backend function
                kwds.pop('xp', None)
                return cupyx_func(*args, **kwds)
            elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
                spx = scipy_namespace_for(xp)
                jax_module = getattr(spx, module_name)
                jax_func = getattr(jax_module, func.__name__)
                # the ``xp`` keyword is not forwarded to the backend function
                kwds.pop('xp', None)
                return jax_func(*args, **kwds)
            else:
                # the original function
                return func(*args, **kwds)
        return wrapper
    return inner
68
69
# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "bode",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "freqresp",
    "gausspulse",
    "iirdesign",  # There's no reason this shouldn't work. It just needs tests.
    "istft",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_poles",
    "sawtooth",
    "sepfir2d",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "welch",
    "zoom_fft",
    "zpk2ss",
}
116
117
def get_default_capabilities(func_name, delegator):
    """Return the capabilities decorator used when no override is declared.

    Parameters
    ----------
    func_name : str
        Name of the public object being decorated.
    delegator : callable or None
        The object's ``*_signature`` function, or None if it has none.

    Returns
    -------
    The result of ``xp_capabilities(np_only=True)`` when there is no
    delegator or the name is listed in ``untested``; otherwise the result of
    ``xp_capabilities()`` with its defaults.
    """
    if delegator is None or func_name in untested:
        return xp_capabilities(np_only=True)
    return xp_capabilities()
122
123bilinear_extra_note = \
124 """CuPy does not accept complex inputs.
125
126 """
127
128uses_choose_conv_extra_note = \
129 """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
130 but does support higher dimensional arrays for ``method="direct"``
131 and ``method="fft"``.
132
133 """
134
135resample_poly_extra_note = \
136 """CuPy only supports ``padtype="constant"``.
137
138 """
139
140upfirdn_extra_note = \
141 """CuPy only supports ``mode="constant"`` and ``cval=0.0``.
142
143 """
144
145xord_extra_note = \
146 """The ``torch`` backend on GPU does not support the case where
147 `wp` and `ws` specify a Bandstop filter.
148
149 """
150
151convolve2d_extra_note = \
152 """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.
153
154 """
155
156zpk2tf_extra_note = \
157 """The CuPy and JAX backends both support only 1d input.
158
159 """
160
161abcd_normalize_extra_note = \
162 """The result dtype when all array inputs are of integer dtype is the
163 backend's current default floating point dtype.
164
165 """
166
# Explicit per-function capability declarations, keyed by public name.
# An entry here takes precedence over ``get_default_capabilities`` when the
# public objects are decorated.
capabilities_overrides = {
    "abcd_normalize": xp_capabilities(extra_note=abcd_normalize_extra_note),
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),

    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  allow_dask_compute=True,
                                  skip_backends=[("jax.numpy", "item assignment")]),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin uses np.interp"),
    "fftconvolve": xp_capabilities(cpu_only=True,
                                   exceptions=["cupy", "jax.numpy", "torch"]),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "hilbert": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "hilbert2": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "fails all around")],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            ("dask.array", "XXX something in dask"),
            ("jax.numpy", "XXX: immutable arrays"),
        ]
    ),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    "sawtooth": xp_capabilities(jax_jit=False,
                                skip_backends=[("dask.array", "dask tests fail")]),
    "sepfir2d": xp_capabilities(np_only=True),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sosfilt_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
            ("jax.numpy", "sosfilt works in-place"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}
385
386
# ### decorate ###
# Walk every public name of ``_signal_api`` and rebind it in this module's
# namespace, wrapped with backend delegation (when enabled) and with its
# capabilities decorator.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    # a delegator is the ``<name>_signature`` function in ``_delegators``, if any
    delegator = getattr(_delegators, obj_name + "_signature", None)

    # delegation is only installed when SCIPY_ARRAY_API is enabled AND the
    # object has a delegator; otherwise the bare object is used
    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    # everything except submodules gets a capabilities decorator: the explicit
    # override if one is declared, else the default for this name
    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)

    # add the decorated function to the namespace, to be imported in __init__.py
    vars()[obj_name] = f
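Note that a function is only wrapped with the delegating decorator if the
environment variable SCIPY_ARRAY_API is set and a matching signature function
(named <function>_signature) exists in the file scipy/signal/_delegators.py;
the capabilities decorator is applied in either case. E.g., for firwin, the
signature function looks like this: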
def firwin_signature(numtaps, cutoff, *args, **kwds):
    """Return the array namespace to use for a ``firwin`` call.

    Takes the same arguments as ``firwin``; only `cutoff` is inspected.
    When `cutoff` is a Python ``int`` or ``float`` the namespace is
    ``np_compat``; otherwise it is ``array_namespace(cutoff)``.
    """
    if isinstance(cutoff, int | float):
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
torch |
jax |
dask |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
cupy |
torch |
jax |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
jax |
|---|---|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |