Disparate Performance between Python and Java · Issue #602 · tensorflow/java · GitHub
Disparate Performance between Python and Java #602
Open
@ryanhausen

Description

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Ubuntu 24.04 x86_64
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: 1.0.0
  • Java version (i.e., the output of java -version): openjdk version "21.0.6" 2025-01-21
  • Java command line flags (e.g., GC parameters):
  • Python version (if transferring a model trained in Python): 3.12.8
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 12.8.61/8905
  • GPU model and memory: V100 (32GB)

Describe the current behavior

Executing the function in TensorFlow Python takes significantly less time than calling the same function, exported as a SavedModel, from TensorFlow Java. I suspect that I am simply not using the Java API correctly, because a small change to the Python code can lead to comparably poor performance in Python.

Describe the expected behavior

The function calls should take a comparable amount of time.

Code to reproduce the issue

I have the following Python function:

import tensorflow as tf

@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="data"),   # [k, n, n]
        tf.TensorSpec(shape=[1, 2048, 2048], dtype=tf.float32, name="image"),   # [1, n, n]
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="psf"),    # [k, n, n]
    ],
    jit_compile=True,  # compile the whole step with XLA
)
def rl_step(
    data: tf.Tensor,  # [k, n, n]
    image: tf.Tensor, # [1, n, n]
    psf: tf.Tensor,   # [k, n, n]
) -> tf.Tensor: # [k, n, n]
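    # Spectra of the PSFs and of the spatially reversed PSFs (last two axes).
    psf_fft = tf.signal.rfft2d(psf)
    psft_fft = tf.signal.rfft2d(tf.reverse(psf, axis=(-2, -1)))
    # Blur each estimate with its PSF, then sum over k: [1, n, n].
    denom = tf.reduce_sum(
        tf.signal.irfft2d(psf_fft * tf.signal.rfft2d(data)),
        axis=0,
        keepdims=True
    )
    # Ratio of the measured image to the summed re-blurred estimate: [1, n, n].
    img_err = image / denom
    # Convolve the ratio with each reversed PSF (broadcast over k) and scale the estimates.
    return data * tf.signal.irfft2d(tf.signal.rfft2d(img_err) * psft_fft)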
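For reference, rl_step appears to be a Richardson-Lucy style update: multiplying spectra from rfft2d and transforming back with irfft2d amounts to a circular convolution over the last two axes. The sketch below is a slow, pure-Python 1-D analogue that could serve as a small-scale correctness check when comparing Python and Java outputs. The names `circular_conv` and `rl_step_1d` are hypothetical and not part of TensorFlow; the sketch mirrors the TensorFlow code literally, including reversing the whole PSF axis as `tf.reverse` does.

```python
from typing import List


def circular_conv(x: List[float], h: List[float]) -> List[float]:
    """Circular convolution: the 1-D analogue of irfft(rfft(x) * rfft(h))."""
    n = len(x)
    return [sum(x[j] * h[(i - j) % n] for j in range(n)) for i in range(n)]


def rl_step_1d(
    data: List[List[float]],  # k estimates, each of length n
    image: List[float],       # one measured signal of length n
    psf: List[List[float]],   # k PSFs, each of length n
) -> List[List[float]]:
    """Hypothetical 1-D mirror of rl_step, for small sanity checks only."""
    n = len(image)
    # denom = sum over k of (data_k convolved with psf_k), like reduce_sum(axis=0, keepdims=True)
    blurred = [circular_conv(d, p) for d, p in zip(data, psf)]
    denom = [sum(b[i] for b in blurred) for i in range(n)]
    # img_err = image / denom, elementwise
    img_err = [image[i] / denom[i] for i in range(n)]
    # data_k * (img_err convolved with the reversed psf_k), like tf.reverse over the axis
    return [
        [d[i] * c for i, c in enumerate(circular_conv(img_err, list(reversed(p))))]
        for d, p in zip(data, psf)
    ]
```

For example, with a single estimate and a delta PSF at index 0, rl_step_1d([[1.0] * 4], [2.0] * 4, [[1.0, 0.0, 0.0, 0.0]]) returns [[2.0, 2.0, 2.0, 2.0]].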

In Python, this function is called repeatedly on the same input tensors, as below:

    image_tensor = tf.constant(image)                # [1, n, n]
    measured_psf_tensor = tf.constant(measured_psf)  # [k, n, n]
    data_tensor = tf.constant(data)                  # [k, n, n]

    for i in range(10):
        start = time()
        data = rl_step(data_tensor, image_tensor, measured_psf_tensor)
        print(f"Iter {i}:", time() - start, "seconds.")

Here image, measured_psf, and data are all 3-D float32 arrays, with n=2048 and k=41.

This prints timings similar to the following:

Iter 0: 0.2061774730682373 seconds.
Iter 1: 0.004193544387817383 seconds.
Iter 2: 0.0007469654083251953 seconds.
Iter 3: 0.000415802001953125 seconds.
Iter 4: 0.0004220008850097656 seconds.
Iter 5: 0.0004246234893798828 seconds.
Iter 6: 0.0004112720489501953 seconds.
Iter 7: 0.00042128562927246094 seconds.
Iter 8: 0.0004055500030517578 seconds.
Iter 9: 0.00040721893310546875 seconds.
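One caveat when reading these numbers: on a GPU, TensorFlow generally launches kernels asynchronously, so a call to rl_step may return before the computation has finished, and sub-millisecond times for 41 pairs of 2048 x 2048 FFTs may reflect dispatch rather than execution. The loop above also passes data_tensor unchanged on every iteration, whereas the Java loop feeds each output back in. Below is a minimal sketch of a timing helper that does both, assuming that reading a result back to the host (for example with `.numpy()`) forces it to be computed; `time_iterations` is a hypothetical helper, not a TensorFlow API.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def time_iterations(
    step: Callable[[T], T],
    x: T,
    n_iter: int,
    sync: Callable[[T], object],
) -> List[float]:
    """Time n_iter applications of step, feeding each output back in as the next input.

    sync is called on every result inside the timed region so that any
    asynchronous work (e.g. queued GPU kernels) is included in the measurement.
    """
    timings = []
    for _ in range(n_iter):
        start = time.perf_counter()
        x = step(x)
        sync(x)  # block until the result is actually available
        timings.append(time.perf_counter() - start)
    return timings


# Hypothetical usage with the tensors from the loop above:
# times = time_iterations(
#     lambda d: rl_step(d, image_tensor, measured_psf_tensor),
#     data_tensor, 10, sync=lambda t: t.numpy())
```

Using `.numpy()` as the sync step also copies each result back to host memory, which the Java loop does as well when it fetches output_0, so the two measurements would cover more similar work.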

I tried exporting the model by adding the following after the timing code:

    mod = tf.Module()
    mod.f = rl_step
    tf.saved_model.save(mod, "pure_tf_export")

I then loaded the exported model with the Java API:

        String modelLocation = "./pure_tf_export";
        // SavedModelBundle creates and owns its own Graph and Session; closing it releases them.
        try (SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();
             Tensor imageTensor = TFloat32.tensorOf(image);
             Tensor psfTensor = TFloat32.tensorOf(psf);
             Tensor dataTensor = TFloat32.tensorOf(data)
        ){
            Map<String, Tensor> inputs = new HashMap<>();
            inputs.put("data", dataTensor);
            inputs.put("image", imageTensor);
            inputs.put("psf", psfTensor);

            Result previous = null;  // result of the previous iteration, still used as "data"
            for (int i = 0; i < 10; i++){

                Instant start = Instant.now();

                Result result = model.function("serving_default").call(inputs);
                // Feed this iteration's output back in as the next "data" input.
                inputs.replace("data", result.get("output_0").get());

                System.out.println("Iter " + i + " " + (Duration.between(start, Instant.now()).toMillis()/1000f) + " seconds");

                // The previous Result is no longer an input; close it to free its native tensors.
                if (previous != null) {
                    previous.close();
                }
                previous = result;
            }
            if (previous != null) {
                previous.close();
            }
        }

And I get timings as follows:

Iter 0 0.701 seconds
Iter 1 0.528 seconds
Iter 2 0.874 seconds
Iter 3 0.224 seconds
Iter 4 0.254 seconds
Iter 5 1.622 seconds
Iter 6 0.241 seconds
Iter 7 0.224 seconds
Iter 8 0.231 seconds
Iter 9 0.228 seconds
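A rough back-of-the-envelope estimate may help put the roughly 0.22 s steady-state Java times in context. Assuming (an unverified assumption) that every call copies data, image, and psf from host memory to the GPU and copies output_0 back to host memory, each iteration moves about 2 GB across the bus. In the Python loop, tensors created with tf.constant may already live on the GPU, which would also be consistent with the observation that skipping tf.constant slows Python down. The sketch below does the arithmetic; `bytes_per_call`, `transfer_seconds`, and the 12 GB/s effective bandwidth are hypothetical placeholders, not measurements of this machine.

```python
def bytes_per_call(k: int, n: int, bytes_per_elem: int = 4) -> int:
    """Bytes moved per call, ASSUMING every feed and fetch crosses host <-> GPU."""
    volume = k * n * n * bytes_per_elem  # one [k, n, n] float32 tensor: data, psf, output_0
    plane = n * n * bytes_per_elem       # the [1, n, n] image
    # Assumed traffic: data, image and psf copied in; output_0 copied back out.
    return 3 * volume + plane


def transfer_seconds(k: int, n: int, bandwidth_bytes_per_s: float) -> float:
    """Rough copy time for one call at the given (assumed) bandwidth."""
    return bytes_per_call(k, n) / bandwidth_bytes_per_s


ASSUMED_BANDWIDTH = 12e9  # hypothetical effective host<->device rate in bytes/s, not measured

total = bytes_per_call(41, 2048)
print(f"{total:,} bytes per call")  # 2,080,374,784 bytes per call
print(f"{transfer_seconds(41, 2048, ASSUMED_BANDWIDTH):.3f} s per call")  # 0.173 s per call
```

If the real bandwidth differs, the estimate scales inversely with it; the point is only that copies of this size could plausibly account for a large share of each Java iteration.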

I am pretty sure I am making a simple mistake somewhere, and I suspect it is in how I am instantiating the tensors. I know that in Python, if you don't use tf.constant, the timings go up a lot.

Any help would be very much appreciated. I tried looking through the documentation and the tensorflow java-examples repository, but couldn't spot what I am doing wrong.

Thanks again!
