大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的知识进行总结与归纳,不仅形成深入且独到的理解,而且能够帮助新手快速入门。
本文主要介绍了AttributeError: module ‘jax’ has no attribute 'Array’解决方案,希望能对使用jax的同学们有所帮助。
文章目录
- 1. 问题描述
- 2. 解决方案
1. 问题描述
今天在运行jax代码时,却遇到了AttributeError: module ‘jax’ has no attribute 'Array’的错误提示,具体报错信息如下图所示:
在经过了亲身的实践后,终于找到了解决问题的方案,最终将逐步的操作过程总结如下。希望能对遇到同样bug的同学有所帮助。
2. 解决方案
经过调研和实践后发现,需要通过以下命令升级对应的Python库,需要说明的是不同环境的安装命令是不同的,具体如下所示:
如果是CPU环境,则执行以下命令进行:
pip install -U "jax[cpu]"
如果是英伟达的GPU,则执行以下命令:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
如果是谷歌的TPU,则执行以下命令:
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
当看到Sucessfully installed jax,则说明安装成功了,然后再运行代码就不会报错了。