Sharding: "Hip error: 'operation would make the legacy stream ..." #275
Comments
Hello @ppanchad-amd, is there any news about this issue? It's a total blocker for us.
Hi @PhilipVinc, I am investigating this. Could you clarify whether you are still using JAX 0.4.35 for this, or 0.5.0?
Hi @lucbruni-amd, I was using jax 0.4.35 because of the issues in #268. The code still errors with jax 0.5.0, but the error is different. See below:
[cad14908] fvicentini@a1018:~/rep2$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
- NETKET_DEBUG = False
- NETKET_EXPERIMENTAL = False
- NETKET_MPI_WARNING = True
- NETKET_MPI = False
- NETKET_USE_PLAIN_RHAT = False
- NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
- NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
- NETKET_EXPERIMENTAL_SHARDING_CPU = 0
- NETKET_ENABLE_X64 = True
- NETKET_SPHINX_BUILD = False
- NETKET_EXPERIMENTAL_SHARDING = True
- NETKET_MPI_AUTODETECT_LOCAL_GPU = False
- NETKET_RANDOM_STATE_FALLBACK_WARNING = True
- NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
- NETKET_SPIN_ORDERING_WARNING = True
jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
0%| | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/lus/home/CT5/cad14908/fvicentini/rep2/simple_nocnn.py", line 44, in <module>
gs.run(n_iter=500, out=log, timeit=True)
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/abstract_variational_driver.py", line 347, in run
for step in self.iter(n_iter, step_size):
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/abstract_variational_driver.py", line 232, in iter
self._dp = self._forward_and_backward()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
result = fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/vmc.py", line 122, in _forward_and_backward
self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
result = fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 710, in expect_and_grad
return expect_and_grad(
^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
return _convert(method(*args, **kw_args), return_type)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect_grad.py", line 58, in expect_and_grad_default_formula
Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, *args, mutable=mutable)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
return _convert(method(*args, **kw_args), return_type)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect_forces.py", line 48, in expect_and_forces
σ, args = get_local_kernel_arguments(vstate, Ô)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
return _convert(method(*args, **kw_args), return_type)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect.py", line 61, in get_local_kernel_arguments
σ = vstate.samples
^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 611, in samples
self.sample()
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
result = fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 575, in sample
self.sampler_state = self.sampler.reset(
^^^^^^^^^^^^^^^^^^^
File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/sampler/base.py", line 276, in reset
return sampler._reset(wrap_afun(machine), parameters, state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(98)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
And the jax version is:
[cad14908] fvicentini@a1018:~/rep2$ uv tree | grep jax
Resolved 46 packages in 4ms
├── jax[rocm] v0.5.0
│ ├── jaxlib v0.5.0
│ ├── jax-rocm60-plugin v0.5.0 (extra: rocm)
│ │ └── jax-rocm60-pjrt v0.5.0
│ └── jaxlib v0.5.0 (extra: rocm) (*)
├── jax-rocm60-pjrt v0.5.0
├── jax-rocm60-plugin v0.5.0 (*)
├── jaxlib v0.5.0 (*)
│ ├── jax v0.5.0 (*)
│ ├── jaxtyping v0.3.0
│ ├── jax v0.5.0 (*)
│ │ │ ├── jax v0.5.0 (*)
│ │ │ ├── jaxlib v0.5.0 (*)
│ │ ├── jax v0.5.0 (*)
│ │ ├── jaxlib v0.5.0 (*)
│ │ ├── jax v0.5.0 (*)
├── jax v0.5.0 (*)
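A listing like the one above can be checked mechanically for a mixed JAX stack. The following is a minimal sketch, not part of the original report: the helper names `collect_versions` and `jax_stack_versions` are hypothetical, and the regular expression assumes the `name vX.Y.Z` layout seen in the `uv tree` output above.

```python
import re

# Matches entries such as "├── jaxlib v0.5.0 (*)" or "jax[rocm] v0.5.0".
# Assumption: every package entry has the form "<name>[extras] v<version>".
PACKAGE_RE = re.compile(r"([A-Za-z0-9_.\-]+)(?:\[[^\]]*\])?\s+v(\d[\w.]*)")


def collect_versions(tree_output: str) -> dict[str, set[str]]:
    """Map each package name to the set of versions listed for it."""
    versions: dict[str, set[str]] = {}
    for line in tree_output.splitlines():
        match = PACKAGE_RE.search(line)
        if match:
            versions.setdefault(match.group(1), set()).add(match.group(2))
    return versions


def jax_stack_versions(tree_output: str) -> set[str]:
    """Return the distinct versions of jax, jaxlib and the jax-rocm* packages."""
    stack: set[str] = set()
    for name, found in collect_versions(tree_output).items():
        if name in ("jax", "jaxlib") or name.startswith("jax-rocm"):
            stack |= found
    return stack


# A few lines taken from the listing above.
sample = """\
├── jax[rocm] v0.5.0
│ ├── jaxlib v0.5.0
│ ├── jax-rocm60-plugin v0.5.0 (extra: rocm)
│ │ └── jax-rocm60-pjrt v0.5.0
│ ├── jaxtyping v0.3.0
"""

# More than one distinct version would indicate a mixed installation.
print(jax_stack_versions(sample))
```

For the listing in this comment the set contains a single version, so a version mismatch between jax, jaxlib and the ROCm plugin does not appear to be the cause here.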
Thanks for the helpful info. I'll respond here when I have updates on this.
Hi @PhilipVinc, just wanted to let you know that I'm able to consistently reproduce your latter error.
Thank you for the update.
Hi @lucbruni-amd is there any update on your side? I just tried rerunning the example from this bug report on And the outcome is again different.
export NETKET_ENABLE_X64=1
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0
[cad14908] fvicentini@a1004:~/rep4$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
- NETKET_DEBUG = False
- NETKET_EXPERIMENTAL = False
- NETKET_MPI_WARNING = True
- NETKET_MPI = False
- NETKET_USE_PLAIN_RHAT = False
- NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
- NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
- NETKET_EXPERIMENTAL_SHARDING_CPU = 0
- NETKET_ENABLE_X64 = True
- NETKET_SPHINX_BUILD = False
- NETKET_EXPERIMENTAL_SHARDING = True
- NETKET_MPI_AUTODETECT_LOCAL_GPU = False
- NETKET_RANDOM_STATE_FALLBACK_WARNING = True
- NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
- NETKET_SPIN_ORDERING_WARNING = True
jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
0%| | 0/500 [00:00<?, ?it/s]
2025-04-11 17:35:12.675020: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 2; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-04-11 17:35:12.676137: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 3; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-04-11 17:35:12.676165: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 1; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
Do you have any timeline for a fix? This is a 100% blocker for us.
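The rendezvous messages above can be summarised by rank to see which participant never reported. This is a rough sketch only: the function name `waiting_ranks` is hypothetical, and it assumes that rank indices coincide with the device ids listed in `clique=devices=[...]`, which the log does not state explicitly.

```python
import re

# Extracts the rank and the clique device list from a rendezvous message.
RENDEZVOUS_RE = re.compile(
    r"initialize clique for rank (\d+); clique=devices=\[([\d,]+)\]"
)


def waiting_ranks(log_text: str) -> tuple[list[int], list[int]]:
    """Return (ranks that logged a wait, clique members that logged nothing)."""
    waiting: set[int] = set()
    members: set[int] = set()
    for match in RENDEZVOUS_RE.finditer(log_text):
        waiting.add(int(match.group(1)))
        members.update(int(d) for d in match.group(2).split(","))
    # Assumption: a clique member with no "waiting" message is the leader
    # (or is otherwise stuck before it could report).
    return sorted(waiting), sorted(members - waiting)


# The quoted fragments of the three messages in the log above.
log_fragments = "\n".join(
    [
        "initialize clique for rank 2; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152",
        "initialize clique for rank 3; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152",
        "initialize clique for rank 1; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152",
    ]
)

reported, silent = waiting_ranks(log_fragments)
print("waiting:", reported, "silent:", silent)
```

Under that assumption, ranks 1, 2 and 3 are waiting and rank 0 is the one that never completes the rendezvous, which matches the "leader can be deadlocked" wording in the messages.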
It seems that it is an x64 (double precision) issue, as setting
export NETKET_ENABLE_X64=0
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0
makes the code run.
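The same three settings can also be applied from inside the script rather than the shell. The sketch below is an illustration, not the script from the gist: the helper names are hypothetical, and it assumes NetKet reads these variables when it is first imported, so they must be set before `import netket`.

```python
import os


def netket_sharding_env(enable_x64: bool) -> dict[str, str]:
    """Build the three environment settings quoted in the comment above."""
    return {
        "NETKET_ENABLE_X64": "1" if enable_x64 else "0",
        "NETKET_EXPERIMENTAL_SHARDING": "1",
        "NETKET_MPI": "0",
    }


def apply_env(settings: dict[str, str]) -> None:
    """Write the settings into this process's environment."""
    os.environ.update(settings)


# Workaround configuration: sharding on, double precision off.
apply_env(netket_sharding_env(enable_x64=False))

# Assumption: NetKet picks these up at import time, so the import
# would go here, after the environment has been set.
# import netket as nk
```

Switching `enable_x64` between `True` and `False` in otherwise identical runs is one way to confirm that double precision is what triggers the failure.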
Description
I am running on 'bare metal' with jax/jaxlib 0.4.35 installed from the releases in this repository. I had to work around the missing libsuitesparseconfig.so.4 by installing it manually.
When running a relatively straightforward script that uses sharding in a single process addressing 4 local GPUs, I get unexpected errors such as the HIP error quoted in the title.
The script and pyproject file can be found at this gist.
System info (python version, jaxlib version, accelerator, etc.)