Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Ubuntu 24.04 x86_64
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 1.0.0
- Java version (i.e., the output of `java -version`): openjdk version "21.0.6" 2025-01-21
- Java command line flags (e.g., GC parameters):
- Python version (if transferring a model trained in Python): 3.12.8
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: 12.8.61/8905
- GPU model and memory: V100 (32GB)
Describe the current behavior
Executing the exported model with TensorFlow in Python takes significantly less time than calling the same function through TensorFlow Java. I suspect that I am just not using the Java API correctly, because a small change to the Python code can lead to comparably poor performance in Python as well.
Describe the expected behavior
The function calls should take a comparable amount of time.
Code to reproduce the issue
I have the following Python function:
```python
@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="data"),   # [k, n, n]
        tf.TensorSpec(shape=[1, 2048, 2048], dtype=tf.float32, name="image"),   # [1, n, n]
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="psf"),    # [k, n, n]
    ],
    jit_compile=True,
)
def rl_step(
    data: tf.Tensor,   # [k, n, n]
    image: tf.Tensor,  # [1, n, n]
    psf: tf.Tensor,    # [k, n, n]
) -> tf.Tensor:        # [k, n, n]
    psf_fft = tf.signal.rfft2d(psf)
    psft_fft = tf.signal.rfft2d(tf.reverse(psf, axis=(-2, -1)))
    denom = tf.reduce_sum(
        tf.signal.irfft2d(psf_fft * tf.signal.rfft2d(data)),
        axis=0,
        keepdims=True,
    )
    img_err = image / denom
    return data * tf.signal.irfft2d(tf.signal.rfft2d(img_err) * psft_fft)
```
In Python, this function is applied iteratively, feeding each result back in, as below:
```python
from time import time

image_tensor = tf.constant(image)                # [1, n, n]
measured_psf_tensor = tf.constant(measured_psf)  # [k, n, n]
data_tensor = tf.constant(data)                  # [k, n, n]
for i in range(10):
    start = time()
    data_tensor = rl_step(data_tensor, image_tensor, measured_psf_tensor)
    print(f"Iter {i}:", time() - start, "seconds.")
```
Here `image`, `measured_psf`, and `data` are all 3D arrays with `dtype=float32`, `n=2048`, and `k=41`.
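For a self-contained run, the inputs can be stand-ins; a minimal sketch (the random arrays here are just placeholders for my real data):

```python
import numpy as np
import tensorflow as tf

k, n = 41, 2048
# Placeholder inputs with the shapes/dtype described above.
image = np.random.rand(1, n, n).astype(np.float32)
measured_psf = np.random.rand(k, n, n).astype(np.float32)
data = np.random.rand(k, n, n).astype(np.float32)
```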
This prints timings around the following:
Iter 0: 0.2061774730682373 seconds.
Iter 1: 0.004193544387817383 seconds.
Iter 2: 0.0007469654083251953 seconds.
Iter 3: 0.000415802001953125 seconds.
Iter 4: 0.0004220008850097656 seconds.
Iter 5: 0.0004246234893798828 seconds.
Iter 6: 0.0004112720489501953 seconds.
Iter 7: 0.00042128562927246094 seconds.
Iter 8: 0.0004055500030517578 seconds.
Iter 9: 0.00040721893310546875 seconds.
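(Note that TensorFlow dispatches GPU ops asynchronously in eager mode, so these sub-millisecond numbers may partly reflect enqueue time rather than compute time. A variant of the loop that blocks on the result before stopping the timer, as a sanity check, could look like this:)

```python
for i in range(10):
    start = time()
    result = rl_step(data_tensor, image_tensor, measured_psf_tensor)
    _ = result.numpy()  # forces a device-to-host copy, so the op has finished
    print(f"Iter {i}:", time() - start, "seconds.")
```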
I tried exporting the model by adding the following after the timing code:
```python
mod = tf.Module()
mod.f = rl_step
tf.saved_model.save(mod, "pure_tf_export")
```
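To confirm what ended up in the export, one quick check is reloading it in Python and printing the serving signature (a sketch; `serving_default` is the signature name I call from Java below):

```python
loaded = tf.saved_model.load("pure_tf_export")
sig = loaded.signatures["serving_default"]
print(sig.structured_input_signature)  # expect data/image/psf TensorSpecs
print(sig.structured_outputs)          # expect a single float32 output
```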
Now I tried to use this exported model from the Java API:
```java
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.tensorflow.Result;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;

String modelLocation = "./pure_tf_export";
try (SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();
     Tensor imageTensor = TFloat32.tensorOf(image);
     Tensor psfTensor = TFloat32.tensorOf(psf);
     Tensor dataTensor = TFloat32.tensorOf(data)) {
    Map<String, Tensor> inputs = new HashMap<>();
    inputs.put("data", dataTensor);
    inputs.put("image", imageTensor);
    inputs.put("psf", psfTensor);
    for (int i = 0; i < 10; i++) {
        Instant start = Instant.now();
        Result result = model.function("serving_default").call(inputs);
        // Feed the output back in as the next iteration's "data" input.
        inputs.replace("data", result.get("output_0").get());
        System.out.println("Iter " + i + " "
                + (Duration.between(start, Instant.now()).toMillis() / 1000f) + " seconds");
    }
}
```
And I get timings as follows:
Iter 0 0.701 seconds
Iter 1 0.528 seconds
Iter 2 0.874 seconds
Iter 3 0.224 seconds
Iter 4 0.254 seconds
Iter 5 1.622 seconds
Iter 6 0.241 seconds
Iter 7 0.224 seconds
Iter 8 0.231 seconds
Iter 9 0.228 seconds
I am pretty sure I am making a simple mistake somewhere. I suspect it is in how I am instantiating the Tensors. I know that in Python the timings go up a lot if you don't use `tf.constant`.
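For reference, this is the kind of Python slowdown I mean (a sketch: passing the raw NumPy arrays instead of `tf.constant` tensors, so each call has to convert, and on GPU copy, the inputs again):

```python
# Same loop as above, but with the raw NumPy arrays: each call converts
# the inputs again instead of reusing already-created device tensors.
for i in range(10):
    start = time()
    _ = rl_step(data, image, measured_psf)
    print(f"Iter {i}:", time() - start, "seconds.")
```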
Any help would be very much appreciated. I tried looking through the documentation and the TensorFlow java-examples repository, but couldn't spot what I am doing wrong.
Thanks again!