[WIP] Compute the determinant of a matrix on the GPU by juancamilog · Pull Request #6193 · Theano/Theano · GitHub

[WIP] Compute the determinant of a matrix on the GPU #6193

Status: Open. juancamilog wants to merge 2 commits into master.
156 changes: 154 additions & 2 deletions theano/gpuarray/linalg.py
@@ -9,12 +9,16 @@

import theano
from theano import Op, config, tensor
-from theano.scalar import bool as bool_t
+from theano.scalar import (bool as bool_t,
+                           neq as scalar_neq_op,
+                           log as scalar_log_op)
from theano.gof import COp, ParamsType
from theano.gpuarray import GpuArrayType

from .basic_ops import (CGpuKernelBase, as_gpuarray_variable, gpu_contiguous,
infer_context_name)
from .elemwise import GpuElemwise
from .subtensor import GpuExtractDiag
from .type import gpu_context_type

try:
@@ -489,7 +493,155 @@ def gpu_cholesky(A, lower=True):
return GpuCholesky(lower)(A)


# TODO: add support for float64
class GpuLU(Op):
    """
    CUSOLVER GPU LU factorization Op.

    Given a non-singular square matrix, computes its LU
    factorization. Useful for computing the matrix determinant.

    Parameters
    ----------
    inplace : bool, optional
        If True, the input matrix is overwritten with its LU
        factorization (default is False).
    check_output : bool, optional
        If True, raise an error when cuSOLVER reports that the
        factorization failed (default is True).
    """
__props__ = ('inplace', 'check_output')

def __init__(self, inplace=False, check_output=True):
self.inplace = inplace
self.check_output = check_output
if self.inplace:
self.destroy_map = {0: [0]}
super(GpuLU, self).__init__()

    def clone_inplace(self):
        return self.__class__(check_output=self.check_output, inplace=True)

def make_node(self, inp):
if not cusolver_available:
raise RuntimeError('CUSOLVER is not available and '
'GpuLU Op can not be constructed.')
if skcuda.__version__ <= '0.5.1':
warnings.warn('The GpuLU op requires scikit-cuda > 0.5.1 to work with CUDA 8')
        if not pygpu_available:
            raise RuntimeError('Missing pygpu. '
                               'Install or update libgpuarray.')
context_name = infer_context_name(inp)

inp = as_gpuarray_variable(inp, context_name)

inp = gpu_contiguous(inp)

        # this op can only operate on float32 matrices, since only
        # the float32 cuSOLVER entry point (cusolverDnSgetrf) is
        # used below.
        # TODO: support float64
        assert inp.ndim == 2
        assert inp.dtype == 'float32'

        # outputs L and U packed in a single matrix, and a pivots array
        pivots_type = GpuArrayType('int32',
                                   broadcastable=inp[0].broadcastable,
                                   context_name=context_name)()
        return theano.Apply(self, [inp], [inp.type(), pivots_type])

def prepare_node(self, node, storage_map, compute_map, impl):
ctx = node.inputs[0].type.context
attach_cusolver_handle_to_context(ctx)

def perform(self, node, inputs, outputs):
context = inputs[0][0].context

# Input matrix.
A = inputs[0]

l, n = A.shape
if l != n:
raise ValueError('A must be a square matrix')

lda = max(1, n)

# cusolver operates on F ordered matrices
if not self.inplace:
LU = pygpu.array(A, copy=True, order='F')
else:
LU = A.T if A.flags['C_CONTIGUOUS'] else A

LU_ptr = LU.gpudata

with context:
workspace_size = cusolver.cusolverDnSgetrf_bufferSize(
context.cusolver_handle, n, n, LU_ptr, lda)

workspace = pygpu.zeros(workspace_size, dtype='float32',
context=context)

pivots = pygpu.zeros(n, dtype='int32', context=context)

dev_info = pygpu.zeros((1,), dtype='int32', context=context)

workspace_ptr = workspace.gpudata
pivots_ptr = pivots.gpudata
dev_info_ptr = dev_info.gpudata

cusolver.cusolverDnSgetrf(
context.cusolver_handle, n, n, LU_ptr, lda, workspace_ptr,
pivots_ptr, dev_info_ptr)

if self.check_output:
val_dev_info = np.asarray(dev_info)[0]
if val_dev_info > 0:
raise LinAlgError('LU decomposition failed')

outputs[1][0] = pivots

outputs[0][0] = LU
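

# Hedged usage sketch: applying the Op above directly in a small graph.
# It assumes a configured gpuarray context with cuSOLVER available; the
# variable names are illustrative only.
#
#     import numpy as np
#     import theano
#     from theano.gpuarray.linalg import GpuLU
#
#     A = theano.tensor.fmatrix('A')
#     LU, pivots = GpuLU()(A)    # packed L\U matrix and 1-based pivots
#     f = theano.function([A], [LU, pivots])
#     lu, piv = f(np.random.randn(3, 3).astype('float32'))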


def gpu_det(A, inplace=False):
    """
    Computes the matrix determinant on the GPU using its LU
    factorization; i.e. if A = P*L*U with L unit lower triangular,
    then det(A) = (-1)**p * prod(diag(U)), where p is the number of
    row interchanges performed during the factorization.

    Parameters
    ----------
    A : square matrix

    Returns
    -------
    det : determinant of A
    """
    LU, pivots = GpuLU(inplace=inplace)(A)
    idx = theano.tensor.arange(1, A.shape[0] + 1, dtype=pivots.dtype)
    # every pivot entry that differs from its (1-based) row index marks
    # a row interchange; the parity of their count gives the sign of det(P)
    p = GpuElemwise(scalar_neq_op)(pivots, idx).sum().astype(A.dtype)
    # the diagonal of the packed LU matrix is the diagonal of U
    diag = GpuExtractDiag(view=True)(LU)
    det = diag.prod() * ((-1) ** p)
    return det


def gpu_slogdet(A, inplace=False):
    """
    Computes the sign and the logarithm of the absolute value of the
    matrix determinant on the GPU using its LU factorization; i.e. if
    A = P*L*U with L unit lower triangular, then
    |det(A)| = prod(|diag(U)|) and
    sign(det(A)) = (-1)**p * prod(sign(diag(U))), where p is the number
    of row interchanges performed during the factorization.

    Parameters
    ----------
    A : square matrix

    Returns
    -------
    s, logabsdet : sign of the determinant and log(|det(A)|)
    """
    LU, pivots = GpuLU(inplace=inplace)(A)
    idx = theano.tensor.arange(1, A.shape[0] + 1, dtype=pivots.dtype)
    p = GpuElemwise(scalar_neq_op)(pivots, idx).sum().astype(A.dtype)
    diag = GpuExtractDiag(view=True)(LU)
    # diag(U) may contain negative entries, so take logs of the
    # absolute values and track the signs separately
    logabsdet = GpuElemwise(scalar_log_op)(abs(diag)).sum()
    s = ((-1) ** p) * theano.tensor.sgn(diag).prod()
    return s, logabsdet


# TODO: add support for float64
class GpuMagmaBase(COp):
"""Base class for magma related operations. Add the necessary headers,
libraries and optionally the location of headers and library.
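For reference, the identity that gpu_det and gpu_slogdet rely on can be checked on the CPU. A minimal sketch using NumPy and SciPy (scipy.linalg.lu_factor stands in for cuSOLVER's getrf; note its pivot array is 0-based, whereas cuSOLVER's is 1-based):

    import numpy as np
    from scipy.linalg import lu_factor

    rng = np.random.RandomState(0)
    A = rng.randn(5, 5).astype('float32')

    LU, piv = lu_factor(A)             # LAPACK getrf: packed L\U and pivots
    p = np.sum(piv != np.arange(5))    # number of row interchanges
    diag = np.diag(LU)                 # diagonal of U (L has unit diagonal)

    det = (-1) ** p * np.prod(diag)
    sign = (-1) ** p * np.prod(np.sign(diag))
    logabsdet = np.sum(np.log(np.abs(diag)))

    assert np.allclose(det, np.linalg.det(A), rtol=1e-4)
    s_ref, l_ref = np.linalg.slogdet(A)
    assert sign == s_ref and np.allclose(logabsdet, l_ref, rtol=1e-4)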
20 changes: 19 additions & 1 deletion theano/gpuarray/opt.py
@@ -78,7 +78,8 @@
from .linalg import (GpuCusolverSolve, MATRIX_STRUCTURES_SOLVE, GpuCholesky,
cusolver_available, GpuMagmaMatrixInverse, gpu_svd,
GpuMagmaCholesky, gpu_qr, GpuMagmaEigh,
-                     GpuCublasTriangularSolve, cublas_available)
+                     GpuCublasTriangularSolve, cublas_available,
+                     GpuLU, gpu_det)
from .neighbours import GpuImages2Neibs

_logger = logging.getLogger("theano.gpuarray.opt")
@@ -2152,6 +2153,21 @@ def local_inplace_gpu_solve(node):
inplace=True)(*node.inputs)]


# determinant
@register_opt('fast_compile')
@op_lifter([theano.tensor.nlinalg.Det])
@register_opt2([theano.tensor.nlinalg.Det], 'fast_compile')
def local_gpu_det(op, context_name, inputs, outputs):
if not cusolver_available:
return
if inputs[0].dtype not in ['float16', 'float32']:
return
if inputs[0].dtype == 'float16':
return gpu_det(inputs[0].astype('float32')).astype('float16')
else:
return gpu_det(inputs[0])


# Cholesky decomposition
def local_gpu_cholesky(op, context_name, inputs, outputs):
if not cusolver_available:
@@ -2163,6 +2179,8 @@ def local_gpu_cholesky(op, context_name, inputs, outputs):
return op(inputs[0].astype('float32')).astype('float16')

return op


matrix_ops_db = LocalGroupDB()
matrix_ops_db2 = LocalGroupDB(local_opt=theano.gof.opt.GraphToGPULocalOptGroup)
matrix_ops_db2.__name__ = "matrix_ops_db2"
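
With this registration in place, graphs built with theano.tensor.nlinalg.det should be lifted onto the GPU when compiled for a gpuarray device. A minimal sketch, assuming device=cuda and a working cuSOLVER install (untested here):

    import numpy as np
    import theano
    from theano.tensor.nlinalg import det

    x = theano.tensor.fmatrix('x')
    f = theano.function([x], det(x))   # Det is rewritten to gpu_det by
                                       # local_gpu_det during compilation

    A = np.random.randn(4, 4).astype('float32')
    assert np.allclose(f(A), np.linalg.det(A), rtol=1e-3)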