75 lines
1.8 KiB
Plaintext
Executable File
75 lines
1.8 KiB
Plaintext
Executable File
import sys

# Command-line handling: exactly one argument is required, and it must be
# either -cuda or -nocuda. Anything else prints the usage line and exits
# with status 1.
if len(sys.argv) != 2 or sys.argv[1] not in ('-cuda', '-nocuda'):
    # A parenthesised single-argument print produces the same output under
    # Python 2 (print statement) and Python 3 (print function).
    print("Usage: %s <-cuda|-nocuda>" % sys.argv[0])
    sys.exit(1)

# True selects the GPU code path below; False selects the NumPy fallback.
use_cuda = sys.argv[1] == '-cuda'

if use_cuda:
    # PyCUDA is imported only on the GPU path, so the -nocuda mode never
    # touches it. pycuda.autoinit is imported purely for its import-time
    # side effect (per the PyCUDA docs it initialises the driver and
    # creates a context).
    import pycuda.autoinit
    import pycuda.driver as drv
    from pycuda.compiler import SourceModule

    # Two element-wise product kernels, one for float buffers and one for
    # int buffers. Each thread handles the element at its own threadIdx.x
    # and there is no bounds check, so callers must launch exactly one
    # thread per element.
    _kernel_source = """
__global__ void multiply_them(float *dest, float *a, float *b)
{
  const int i = threadIdx.x;
  dest[i] = a[i] * b[i];
}

__global__ void multiply_them_int(int *dest, int *a, int *b)
{
  const int i = threadIdx.x;
  dest[i] = a[i] * b[i];
}
"""
    mod = SourceModule(_kernel_source)

    # Callable handles to the compiled kernels, used by the test cases below.
    multiply_them = mod.get_function("multiply_them")
    multiply_them_int = mod.get_function("multiply_them_int")

import numpy
import imathnumpy
from imath import FloatArray, IntArray
import particle

# ---------------------------------------------------------------------------
# Float case: multiply two imath float arrays element-wise via NumPy views.
# ---------------------------------------------------------------------------

# Number of elements per array; also used as the CUDA block width.
length = 10

# Inputs come from the project-local `particle` module.
# NOTE(review): presumably `length` random values between 1 and 5 -- confirm
# against particle.random.
fa = particle.random(1, 5, length)
fb = particle.random(1, 5, length)
fdest = FloatArray(length)

# NumPy counterparts of the imath arrays. The result is written through
# `dest` but read back through `fdest`, so the script relies on
# arrayToNumpy returning a view that shares storage with the imath array
# (presumably -- confirm against imathnumpy).
a = imathnumpy.arrayToNumpy(fa)
b = imathnumpy.arrayToNumpy(fb)
dest = imathnumpy.arrayToNumpy(fdest)

if use_cuda:
    # One block of `length` threads: one thread per element. Per the PyCUDA
    # docs, drv.In arguments are copied to the device before the launch and
    # drv.Out arguments are copied back to the host afterwards.
    multiply_them(drv.Out(dest), drv.In(a), drv.In(b), block=(length, 1, 1))
else:
    # CPU path: NumPy element-wise product, stored in place through `dest`.
    dest[:] = a * b

# Difference between the stored output and the product computed directly on
# the imath arrays. It is only printed, never asserted.
results = fdest - fa * fb

# "%s" with a one-element tuple reproduces the Python 2 statement
# `print "a:", value` exactly, and is also valid on Python 3.
print("a: %s" % (particle.join(fa, ' '),))
print("b: %s" % (particle.join(fb, ' '),))
print("dest: %s" % (particle.join(fdest, ' '),))
print("diff: %s" % (particle.join(results, ' '),))

# ---------------------------------------------------------------------------
# Int case: same check as the float case, using the integer kernel.
# ---------------------------------------------------------------------------

# NOTE(review): particle.floor is project-local; the result is later passed
# to the int kernel, so it presumably yields an integer imath array --
# confirm against the particle module.
ia = particle.floor(particle.random(1, 5, length))
ib = particle.floor(particle.random(1, 5, length))
idest = IntArray(length)

# Rebind the NumPy names to the integer arrays. As in the float case, the
# result is written through `dest` and read back through `idest`, which
# presumably share storage (confirm against imathnumpy).
a = imathnumpy.arrayToNumpy(ia)
b = imathnumpy.arrayToNumpy(ib)
dest = imathnumpy.arrayToNumpy(idest)

if use_cuda:
    # One block of `length` threads: one thread per element.
    multiply_them_int(drv.Out(dest), drv.In(a), drv.In(b), block=(length, 1, 1))
else:
    # CPU path: NumPy element-wise product, stored in place through `dest`.
    dest[:] = a * b

# Difference between the stored output and the product computed directly on
# the imath arrays. It is only printed, never asserted.
results = idest - ia * ib

# "%s" with a one-element tuple reproduces the Python 2 statement
# `print "a:", value` exactly, and is also valid on Python 3.
print("a: %s" % (particle.join(ia, ' '),))
print("b: %s" % (particle.join(ib, ' '),))
print("dest: %s" % (particle.join(idest, ' '),))
print("diff: %s" % (particle.join(results, ' '),))