I have a .pkl file that I downloaded from a public GitHub repository. When I read it with pickle.load, i.e. using
import pickle

with open('filename.pkl', 'rb') as f:
    file_content = pickle.load(f)
I get the following error:
ModuleNotFoundError: No module named 'jax._src.device_array'
From, e.g., this StackOverflow question, I understand that this is a jax version issue. Specifically, the .pkl file must have been created with a jax version < 0.4, while I am currently using v0.4.31.
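Because pickle stores each class or function by its import path and looks it up again at load time, the error means the file refers to something under jax._src.device_array that no longer exists in 0.4.x. To see exactly which names the file depends on, one option is to subclass pickle.Unpickler and override its documented find_class hook so that unresolvable names are recorded and replaced by a stand-in instead of raising. This is a sketch using only the standard library; `MissingGlobal`, `ReportingUnpickler` and `inspect_pickle` are hypothetical names, not part of pickle or jax.

```python
import pickle


class MissingGlobal:
    """Stand-in for a class or function that cannot be imported here.

    It records whatever the pickle passes when rebuilding the object,
    so the raw data can be inspected after loading.
    """

    # Class-level defaults: if the pickle creates the object via __new__
    # (NEWOBJ opcode), __init__ is never called.
    args = ()
    kwargs = None
    state = None

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __setstate__(self, state):
        self.state = state


class ReportingUnpickler(pickle.Unpickler):
    """Unpickler that records unresolvable globals instead of raising."""

    def __init__(self, file):
        super().__init__(file)
        self.missing = set()

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except (ModuleNotFoundError, AttributeError):
            self.missing.add(f"{module}.{name}")
            # A fresh subclass per missing name keeps the original name visible.
            return type(name, (MissingGlobal,), {"__module__": module})


def inspect_pickle(path):
    """Load `path`, returning (data, set of names that could not be imported)."""
    with open(path, "rb") as f:
        unpickler = ReportingUnpickler(f)
        data = unpickler.load()
    return data, unpickler.missing
```

In the jax v0.4.31 environment, `data, missing = inspect_pickle('filename.pkl')` would be expected to return names under jax._src.device_array instead of raising; the exact names are not confirmed here. The stand-ins only hold the raw arguments the old jax version passed when pickling, so they show what is missing rather than rebuilding the arrays, and objects reconstructed through list or dict appends would still fail.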
I then created a separate conda environment with jax=0.3.25, following @jakevdp's answer to that StackOverflow question, and indeed I was able to load the .pkl file. Following @jakevdp's advice again, I saved the content of the file with jax.numpy.save, as follows:
jnp.save('filename.npy', file_content)
which I can then read with
file_content = jnp.load('filename.npy', allow_pickle=True)
However, the constraints of my specific use case require me to use the content of this file with jax v0.4.31. This brings me back to the same issue as above, namely that DeviceArray is not recognised.
How can I convert the (many) DeviceArrays in the original file into Arrays usable by the newer version of jax?
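The .npy detour presumably keeps failing because file_content is a container (e.g. a dict) rather than a single array: such an object ends up stored as an object array whose elements are pickled, so each DeviceArray is still written by reference to its jax class. One possible workaround is to convert every array leaf to a plain NumPy array while still in the jax=0.3.25 environment, so the saved file references only NumPy. A minimal sketch, assuming the content is built from dicts, lists and tuples whose array leaves implement `__array__` (DeviceArrays do); `to_numpy_tree`, `save_numpy_only`, `load_numpy_only` and the output filename are hypothetical.

```python
import pickle

import numpy as np


def to_numpy_tree(obj):
    """Recursively replace array-like leaves with plain NumPy arrays.

    Assumption: containers are dicts, lists or tuples (including
    namedtuples). Any leaf exposing __array__ (e.g. a jax DeviceArray)
    is converted with np.asarray; everything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: to_numpy_tree(v) for k, v in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: rebuild with the same type and positional fields
        return type(obj)(*(to_numpy_tree(v) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_numpy_tree(v) for v in obj)
    if hasattr(obj, "__array__"):
        return np.asarray(obj)
    return obj


def save_numpy_only(content, path):
    """Pickle `content` after converting every array leaf to NumPy."""
    with open(path, "wb") as f:
        pickle.dump(to_numpy_tree(content), f)


def load_numpy_only(path):
    """Load a file written by save_numpy_only; needs NumPy but not jax."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Usage, under the same assumptions: in the jax=0.3.25 environment, call `save_numpy_only(file_content, 'filename_numpy.pkl')`; in the jax v0.4.31 environment, `data = load_numpy_only('filename_numpy.pkl')` returns NumPy arrays, which `jax.tree_util.tree_map(jnp.asarray, data)` should turn into jax Arrays. If the containers are ones jax already treats as pytrees, `jax.tree_util.tree_map(np.asarray, file_content)` in the old environment would likely do the same job as `to_numpy_tree`.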