Description
The model initialization step in Colab, using a TPU and the default configuration, exits with an error.
The errors are nested through jax and hypernerf, but the root appears to be in hypernerf/hypernerf/model_utils.py (line 119 in d433ebe), within the volumetric_rendering function, at the call jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape).
The relevant error is:

/usr/local/lib/python3.7/dist-packages/hypernerf/model_utils.py in volumetric_rendering(rgb, sigma, z_vals, dirs, use_white_background, sample_at_infinity, eps)
    113 z_vals[..., 1:] - z_vals[..., :-1],
--> 114 jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    115 ], -1)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _broadcast_to(arr, shape)
    341 return arr.broadcast_to(shape)
--> 342 _check_arraylike("broadcast_to", arr)
    343 arr = arr if isinstance(arr, ndarray) else _asarray(arr)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
    294 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 295 raise TypeError(msg.format(fun_name, type(arg), pos))
    296

UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
A quick search turned up https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax, which suggests that, unlike NumPy, jax.numpy functions do not accept Python lists, so a list argument has to be converted to a jnp array first. I haven't gotten it working yet, though.
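For reference, here is a minimal sketch of the failing call and one possible workaround, run outside of hypernerf. The shape of z_vals and the value of last_sample_z are placeholders I made up for illustration and are not taken from the library. The only change is wrapping the list in jnp.array before broadcasting; I have not confirmed that this is the fix the maintainers would choose.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the values inside volumetric_rendering.
z_vals = jnp.linspace(2.0, 6.0, 8).reshape(2, 4)  # assumed layout: (rays, samples)
last_sample_z = 1e10  # assumed to be a plain Python float

# The call from the traceback: a Python list is passed straight to jnp.broadcast_to.
try:
    jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    list_error = None
except TypeError as err:
    list_error = err  # JAX versions that reject list inputs end up here

# Possible workaround: convert the list to a jnp array before broadcasting.
last = jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape)

# Same concatenation as in the traceback, now with array inputs only.
dists = jnp.concatenate([z_vals[..., 1:] - z_vals[..., :-1], last], -1)
print(list_error)
print(dists.shape)
```

If last_sample_z really is a scalar, passing it directly without the list brackets should also satisfy the "ndarray or scalar" check named in the error message, but I have only sketched the jnp.array variant here.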