致谢师兄的 jax 环境,完全按照师兄的 conda_env.yml 配置的
(如何导出其他环境的 conda_env.yml:Conda | 如何(在新服务器上)复制一份旧服务器的 conda 环境,Linux 服务器)
目录
- 01 安装各种库
- 02 安装 jax
- 03 安装 dm_control metaworld d4rl
- 04 测试
- 05 各种库的参考版本
首先,新建一个 conda 环境:
conda create -n jax_env python==3.8
conda activate jax_env
(如何配置 conda:Conda | 如何在 Linux 服务器安装 conda)
01 安装各种库
直接 pip 安装:
pip install numpy==1.21.6 torch==1.13.1 wandb==0.15.10 \
transformers==4.30.2 typing-extensions==4.7.1 optax==0.1.4 \
jax==0.3.24 flax==0.6.0 cloudpickle==2.2.1 distrax==0.1.3 \
glfw==2.6.2 gym==0.15.7
02 安装 jax
jax 把自己的库放在了网站上:
- https://storage.googleapis.com/jax-releases/jax_releases.html
- https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
要安装 0.3.24 的 jax,可以运行:
pip install "jax[cuda11_cudnn82]==0.3.24" \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
需要注意:
- jax jaxlib optax flax 等库,它们的版本有对应关系,可按照这篇博客的参考版本安装;
- 需要 pip install cloudpickle==2.2.1,好像很容易安装成 1.2.2 版本,最后要检查一下版本;
- 编译的时候,因为 ptxas 版本太低报错,可以运行 which ptxas,查看现在在用哪个 ptxas 版本。如果发现在用老 cuda 版本,则去改 path,修改 ~/.bashrc,添加
export PATH="/usr/local/cuda-{版本号}/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-{版本号}/lib64:$LD_LIBRARY_PATH"
# cuda 版本号可以看 /usr/local 目录里有哪些版本,我用的是 11.7
03 安装 dm_control metaworld d4rl
需要先安装 MuJoCo,可参见这篇:Python · MuJoCo | MuJoCo 与 mujoco_py 的版本对应,以及安装 Cython<3
先把 dm_control metaworld d4rl 这三个库拿下来:
git clone git@github.com:Farama-Foundation/Metaworld.git
git clone git@github.com:Farama-Foundation/D4RL.git
git clone git@github.com:denisyarats/dmc2gym.git
然后分别进入它们的路径,执行 pip install -e . 即可。
04 测试
我跑的是 https://github.com/csmile-1006/PreferenceTransformer 这个库,它里面也有 IQL 的 jax 实现,所以这个环境应该是能跑 IQL jax 的)
05 各种库的参考版本
以下是一个参考环境的版本:
name: jax_env
channels:- defaults
dependencies:- _libgcc_mutex=0.1=main- ca-certificates=2023.08.22=h06a4308_0- certifi=2022.12.7=py37h06a4308_0- ld_impl_linux-64=2.38=h1181459_1- libffi=3.3=he6710b0_2- libgcc-ng=9.1.0=hdf63c60_0- libstdcxx-ng=9.1.0=hdf63c60_0- ncurses=6.3=h7f8727e_2- openssl=1.1.1w=h7f8727e_0- pip=22.3.1=py37h06a4308_0- python=3.7.13=h12debd9_0- readline=8.1.2=h7f8727e_1- setuptools=65.6.3=py37h06a4308_0- sqlite=3.38.5=hc218d9a_0- tk=8.6.12=h1ccaba5_0- wheel=0.38.4=py37h06a4308_0- xz=5.2.5=h7f8727e_1- zlib=1.2.12=h7f8727e_2- pip:- absl-py==1.4.0- appdirs==1.4.4- beautifulsoup4==4.12.2- cffi==1.15.1- charset-normalizer==3.2.0- chex==0.1.5- click==8.1.7- cloudpickle==2.2.1- colorama==0.4.6- commonmark==0.9.1- contextlib2==21.6.0- cycler==0.11.0- cython==3.0.2- decorator==5.1.1- distrax==0.1.3- dm-control==1.0.13- dm-env==1.6- dm-tree==0.1.8- docker-pycreds==0.4.0- etils==0.9.0- fasteners==0.18- filelock==3.12.2- flax==0.6.0- fonttools==4.38.0- fsspec==2023.1.0- future==0.18.3- gast==0.5.4- gdown==4.7.1- gitdb==4.0.10- gitpython==3.1.36- glfw==2.6.2- gym==0.15.7- gym-notices==0.0.8- h5py==3.8.0- huggingface-hub==0.16.4- idna==3.4- imageio==2.31.2- imageio-ffmpeg==0.4.9- importlib-metadata==6.7.0- importlib-resources==5.12.0- jax==0.3.24- jaxlib==0.3.24+cuda11.cudnn82- joblib==1.3.2- kiwisolver==1.4.5- labmaze==1.0.6- lxml==4.9.3- matplotlib==3.5.3- ml-collections==0.1.1- msgpack==1.0.5- mujoco==2.3.6- mujoco-py==2.0.2.13- numpy==1.21.6- nvidia-cublas-cu11==11.10.3.66- nvidia-cuda-nvrtc-cu11==11.7.99- nvidia-cuda-runtime-cu11==11.7.99- nvidia-cudnn-cu11==8.5.0.96- opt-einsum==3.3.0- optax==0.1.4- packaging==23.1- pathtools==0.1.2- pillow==9.5.0- protobuf==3.20.1- psutil==5.9.5- pybullet==3.2.5- pycparser==2.21- pyglet==1.5.0- pygments==2.16.1- pyopengl==3.1.7- pyparsing==3.1.1- pysocks==1.7.1- python-dateutil==2.8.2- pyyaml==6.0.1- regex==2023.8.8- requests==2.31.0- rich==11.2.0- safetensors==0.3.3- scikit-learn==1.0.2- scipy==1.7.3- sentry-sdk==1.31.0- setproctitle==1.3.2- six==1.16.0- smmap==5.0.1- soupsieve==2.4.1- tensorboardx==2.1- tensorflow-probability==0.19.0- termcolor==2.3.0- threadpoolctl==3.1.0- tokenizers==0.13.3- toolz==0.12.0- torch==1.13.1- tqdm==4.66.1- transformers==4.30.2- typing-extensions==4.7.1- ujson==5.7.0- urllib3==2.0.4- wandb==0.15.10- zipp==3.15.0
prefix: /home/user_name/miniconda3/envs/jax