An Issue regarding "line_search_armijo" function in a GPU Environment · Issue #445 · PythonOT/POT · GitHub
An Issue regarding "line_search_armijo" function in a GPU Environment #445
Closed
@Nuwaisir-1998

Description

I'm using the gcg method of optim.py in a GPU environment, and it calls the line_search_armijo function. Inside line_search_armijo there is a call to scalar_search_armijo from scipy.optimize (in _linesearch.py). In scalar_search_armijo, if the loop starting with "while alpha1 > amin:" is executed, the statement "alpha2 = (-b + np.sqrt(abs(b**2 - 3 * a * derphi0))) / (3.0*a)" raises an error because np is applied to CUDA variables. Importantly, a and b live on the CUDA device because they are computed from the values of phi, which is defined in line_search_armijo, passed to scalar_search_armijo, and returns a value residing on the CUDA device.

A simple change inside scipy.optimize's scalar_search_armijo (removing the np dependency and moving alpha2 to the CPU) solves the issue, but I don't think modifying scipy's code is a good idea. Can you suggest a better solution?
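One possible alternative that leaves scipy untouched is sketched below. This is an assumption, not POT's actual fix: make every scalar that reaches the line search a plain Python float, so that a, b and alpha2 are computed on the CPU. The hypothetical helper host_scalar wraps phi and calls float() on its result. float() accepts Python numbers, NumPy scalars and one-element PyTorch tensors, and for a CUDA tensor it copies the value to the host. The hypothetical armijo_search is a CPU-only, pure-Python restatement of the quadratic-then-cubic backtracking used by scalar_search_armijo. It uses math.sqrt, a simplified safeguard and an iteration cap, so it is not a verbatim copy of scipy's code. A Fraction-valued phi stands in for a CUDA-valued one so the example runs without a GPU.

```python
import math
from fractions import Fraction
from typing import Callable, Optional, Tuple


def host_scalar(phi: Callable[[float], object]) -> Callable[[float], float]:
    """Hypothetical wrapper: make phi return a plain Python float.

    float() works on Python numbers, NumPy scalars and one-element torch
    tensors (copying a CUDA value to the host), so everything the line
    search derives from phi stays on the CPU.
    """
    def wrapped(alpha: float) -> float:
        return float(phi(alpha))
    return wrapped


def armijo_search(phi: Callable[[float], float], phi0: float, derphi0: float,
                  c1: float = 1e-4, alpha0: float = 1.0, amin: float = 0.0,
                  max_iter: int = 100) -> Tuple[Optional[float], float]:
    """Hypothetical CPU-only Armijo backtracking (quadratic, then cubic steps).

    Mirrors the structure of scipy's scalar_search_armijo, but the
    safeguard and the max_iter cap are simplifications assumed here.
    """
    phi_a0 = phi(alpha0)
    if phi_a0 <= phi0 + c1 * alpha0 * derphi0:
        return alpha0, phi_a0

    # Minimizer of the quadratic matching phi0, derphi0 and phi(alpha0).
    alpha1 = -derphi0 * alpha0 ** 2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)
    if phi_a1 <= phi0 + c1 * alpha1 * derphi0:
        return alpha1, phi_a1

    for _ in range(max_iter):
        if alpha1 <= amin:
            break
        try:
            # Cubic phi0 + derphi0*t + b*t**2 + a*t**3 through the last two trials.
            factor = alpha0 ** 2 * alpha1 ** 2 * (alpha1 - alpha0)
            r1 = phi_a1 - phi0 - derphi0 * alpha1
            r0 = phi_a0 - phi0 - derphi0 * alpha0
            a = (alpha0 ** 2 * r1 - alpha1 ** 2 * r0) / factor
            b = (-alpha0 ** 3 * r1 + alpha1 ** 3 * r0) / factor
            # Same formula as the failing line, but on floats with math.sqrt.
            alpha2 = (-b + math.sqrt(abs(b ** 2 - 3 * a * derphi0))) / (3.0 * a)
        except ZeroDivisionError:
            alpha2 = alpha1 / 2.0  # degenerate interpolant: plain halving
        phi_a2 = phi(alpha2)
        if phi_a2 <= phi0 + c1 * alpha2 * derphi0:
            return alpha2, phi_a2
        # Simplified safeguard (assumed): halve if the step is out of range.
        if not (0.1 * alpha1 <= alpha2 <= 0.5 * alpha1):
            alpha2 = alpha1 / 2.0
            phi_a2 = phi(alpha2)
        alpha0, alpha1 = alpha1, alpha2
        phi_a0, phi_a1 = phi_a1, phi_a2
    return None, phi_a1  # no acceptable step found


# Usage: phi_demo returns a Fraction, standing in for a phi whose values
# are not plain floats (a one-element CUDA tensor in the issue's setting).
def phi_demo(t: float) -> Fraction:
    t = Fraction(t)
    return -t + 100 * t ** 2 - 99 * t ** 3


alpha, phi_alpha = armijo_search(host_scalar(phi_demo), phi0=0.0, derphi0=-1.0)
```

With scipy's own function, the same idea would amount to passing host_scalar(phi) together with float() conversions of the phi0 and derphi0 arguments to scalar_search_armijo. The vendored armijo_search is only worth keeping if relying on scipy's private _linesearch module is undesirable. Each float() call on a CUDA value implies one device-to-host copy per phi evaluation, which should be cheap for a single scalar.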
