378 lines
13 KiB
C
Executable File
378 lines
13 KiB
C
Executable File
/******************************************************************
|
|
* mexProd2.c : C-MEX file to compute the product of two matrices.
|
|
*
|
|
* P = mexProd2(blk,A,B,type)
|
|
*
|
|
* input: blk = 1x2 cell array describing the block structure of A and B
|
|
* A = mxn matrix.
|
|
* B = nxp matrix.
|
|
* type = 0 general matrix product
|
|
* 1 if P is symmetric
|
|
* 21 if B is upper triangular
|
|
* 211 if B is upper triangular, P is symmetric
|
|
* 22 if B is lower triangular
|
|
* 221 if B is lower triangular, P is symmetric
|
|
* 31 if A is upper triangular
|
|
* 311 if A is upper triangular, P is symmetric
|
|
* 32 if A is lower triangular
|
|
* 321 if A is lower triangular, P is symmetric
|
|
* 41 if A is upper triangular, B is upper triangular
|
|
* 42 if A is upper triangular, B is lower triangular
|
|
* 421 if A is upper triangular, B is lower triangular,
|
|
* P is symmetric
|
|
*
|
|
* SDPT3: version 3.0
|
|
* Copyright (c) 1997 by
|
|
* K.C. Toh, M.J. Todd, R.H. Tutuncu
|
|
* Last Modified: 2 Feb 01
|
|
******************************************************************/
|
|
|
|
#include <math.h>
|
|
#include <mex.h>
|
|
|
|
/* Integer minimum.  A function rather than a macro over file-scope
 * temporaries: each argument is evaluated exactly once, the arguments
 * are type-checked, and no global state is written on every use. */
static int imin(int a, int b)
{
    return (a < b) ? a : b;
}

#define IMIN(a,b) imin((a), (b))
/**********************************************************
 * saxpy: z[i+idx2] += x * y[i+idx1]   for istart <= i < iend
 *
 *   x       scalar multiplier
 *   y       source vector, read at offset idx1
 *   idx1    offset added to i when indexing y
 *   z       destination vector, updated in place at offset idx2
 *   idx2    offset added to i when indexing z
 *   istart  first i (inclusive)
 *   iend    last i (exclusive); nothing is done if iend <= istart
 *
 * istart/iend are explicitly int: the former "const istart" form
 * relied on implicit int, which C99 removed.
 **********************************************************/
static void saxpy(const double x, const double *y, const int idx1,
                  double *z, const int idx2, const int istart, const int iend)
{
    int i;

    /* Every z element is updated independently, so this plain loop gives
     * exactly the same results as the former manual 4/2/1 unrolling; the
     * compiler is free to unroll/vectorize it. */
    for (i = istart; i < iend; i++) {
        z[i + idx2] += x * y[i + idx1];
    }
}
/**********************************************************
 * symmetrize: mirror the strict upper triangle of the n x n
 * column-major matrix P into its strict lower triangle, so
 * that P(r,c) = P(c,r) for every r > c.  The diagonal and the
 * upper triangle are left as they are.
 **********************************************************/
void symmetrize(double *P, int n)
{
    int row, col;

    /* row 0 has no entries below the diagonal, so start at 1 */
    for (row = 1; row < n; row++) {
        const double *upper = P + row * n;   /* column 'row' of P */
        for (col = 0; col < row; col++) {
            P[row + col * n] = upper[col];   /* P(row,col) = P(col,row) */
        }
    }
}
/********************************************************
 * product_idx: rows [index[0], index[1]) of column k of A
 * that contribute to column j of P = A*B for a given type.
 *
 *   index  output: index[0] = istart, index[1] = iend
 *   j      column of P (and of B) being formed
 *   k      column of A (row of B) being applied
 *   m      number of rows of A and of P
 *   type   product type code (see file header)
 *
 * Unrecognised type codes fall back to the general range
 * [0, m) instead of leaving index[] unset.  iend is clamped
 * to m so a caller never indexes past column k of A or
 * column j of P, even when A or P is not square.
 ********************************************************/
void product_idx(int index[2], int j, int k, int m, int type)
{
    int istart = 0;
    int iend = m;

    switch (type) {
    case 1: case 211: case 221:      /* P symmetric: rows 0..j only */
        iend = j+1;
        break;
    case 31: case 41: case 42:       /* A upper triangular: rows 0..k */
        iend = k+1;
        break;
    case 311: case 421:              /* A upper triangular, P symmetric */
        iend = IMIN(k,j)+1;
        break;
    case 32:                         /* A lower triangular: rows k..m-1 */
        istart = k;
        break;
    case 321:                        /* A lower triangular, P symmetric */
        istart = k;
        iend = j+1;
        break;
    case 0: case 21: case 22:        /* all rows */
    default:
        break;
    }
    if (iend > m) { iend = m; }

    index[0] = istart;
    index[1] = iend;
}
/**********************************************************
 * product: P = A*B with A (m x n) and B (n x p) both dense,
 * column-major.  P (m x p) is accumulated column by column
 * with saxpy, so it is expected to be zero on entry.
 *
 * type (see file header) restricts the work:
 *   - triangular B limits which k contribute to column j,
 *   - triangular A / symmetric P limit the rows (product_idx),
 *   - for symmetric P only the upper triangle is formed and
 *     then mirrored by symmetrize(P,m), which treats P as m x m.
 * Unrecognised type codes are treated as a general product.
 **********************************************************/
static void product(double *A, double *B, double *P,
                    int m, int n, int p, int type)
{
    int j, k, jm, jn, kstart, kend, istart, iend;
    int index[2];
    double tmp;

    for (j = 0; j < p; ++j) {
        if (type == 22 || type == 221 || type == 42 || type == 421) {
            /* B lower triangular: B(k,j) = 0 for k < j */
            kstart = j; kend = n;
        } else if (type == 21 || type == 211 || type == 321 || type == 41) {
            /* B upper triangular (or A lower with P symmetric): k <= j */
            kstart = 0; kend = j+1;
            if (kend > n) { kend = n; }   /* stay within B's n rows */
        } else {
            /* 0, 1, 31, 311, 32 and any unrecognised code */
            kstart = 0; kend = n;
        }
        jm = j*m; jn = j*n;
        for (k = kstart; k < kend; ++k) {
            tmp = B[k+jn];
            if (tmp != 0) {
                product_idx(index, j, k, m, type);
                istart = index[0]; iend = index[1];
                saxpy(tmp, A, k*m, P, jm, istart, iend);
            }
        }
    }
    if (type==1 || type==211 || type==221 || type==311 || type==321 || type==421)
        { symmetrize(P,m); }
}
/**********************************************************
 * product2: P = A*B with A dense (m x n) and B sparse (n x p,
 * compressed-column arrays B/irB/jcB).  P is accumulated in
 * place, so it is expected to be zero on entry.  Each stored
 * B(row,col) adds B(row,col) * A(rows,row) into P(rows,col),
 * with the row range chosen by product_idx.  For symmetric-P
 * types the lower triangle is filled in by symmetrize.
 **********************************************************/
static void product2(double *A, double *B, int *irB, int *jcB,
                     double *P, int m, int n, int p, int type)
{
    int col, nz, row, colOffset;
    int range[2];
    double bval;

    for (col = 0; col < p; col++) {
        colOffset = col * m;
        for (nz = jcB[col]; nz < jcB[col+1]; nz++) {
            row = irB[nz];
            bval = B[nz];
            product_idx(range, col, row, m, type);
            if (bval == 0) { continue; }
            saxpy(bval, A, row*m, P, colOffset, range[0], range[1]);
        }
    }

    switch (type) {
    case 1: case 211: case 221: case 311: case 321: case 421:
        symmetrize(P, m);
        break;
    default:
        break;
    }
}
/**********************************************************
 * A sparse, B dense
 *
 * product3_idx: for column j of P = A*B, store in index[]
 *   index[0], index[1]  range [kstart, kend) of rows k of B
 *                       (columns of A) to visit
 *   index[2]            1 if only the upper triangle of P is
 *                       formed (symmetric P), else 0
 *
 *   n     number of rows of B (columns of A)
 *   type  product type code (see file header)
 *
 * Unrecognised type codes fall back to a general product
 * ([0, n), sym = 0) instead of leaving index[] unset, and kend
 * is clamped to n so A and B are never indexed past n.
 **********************************************************/
void product3_idx(int index[3], int j, int n, int type)
{
    int kstart = 0;
    int kend = n;
    int sym = 0;

    switch (type) {
    case 1: case 311: case 321:      /* symmetric P, all k */
        sym = 1;
        break;
    case 211:                        /* B upper triangular, symmetric P */
        kend = j+1; sym = 1;
        break;
    case 221: case 421:              /* B lower triangular, symmetric P */
        kstart = j; sym = 1;
        break;
    case 21: case 41:                /* B upper triangular */
        kend = j+1;
        break;
    case 22: case 42:                /* B lower triangular */
        kstart = j;
        break;
    case 0: case 31: case 32:        /* all k */
    default:
        break;
    }
    if (kend > n) { kend = n; }

    index[0] = kstart;
    index[1] = kend;
    index[2] = sym;
}
/**********************************************************
 * product3: P = A*B with A sparse (m x n, compressed-column
 * arrays A/irA/jcA) and B dense (n x p).  P is accumulated in
 * place, so it is expected to be zero on entry.  For column
 * col of P, each nonzero B(k,col) in the range chosen by
 * product3_idx scatters B(k,col) * A(:,k) into P(:,col).
 * When only the upper triangle is wanted, scattering column k
 * stops at the first stored row > col; this early exit
 * presumably relies on irA being ascending within each column
 * (MATLAB sparse order) -- confirm for other callers.
 * Symmetric types are completed by symmetrize at the end.
 **********************************************************/
static void product3(double *A, int *irA, int *jcA, double *B,
                     double *P, int m, int n, int p, int type)
{
    int col, k, nz, row, last, upperOnly, colP, colB;
    int range[3];
    double bval;

    for (col = 0; col < p; col++) {
        product3_idx(range, col, n, type);
        upperOnly = range[2];
        colP = col * m;
        colB = col * n;
        for (k = range[0]; k < range[1]; k++) {
            bval = B[k + colB];
            if (bval == 0) { continue; }
            last = jcA[k+1];
            for (nz = jcA[k]; nz < last; nz++) {
                row = irA[nz];
                if (upperOnly && row > col) { break; }
                P[row + colP] += bval * A[nz];
            }
        }
    }

    switch (type) {
    case 1: case 211: case 221: case 311: case 321: case 421:
        symmetrize(P, m);
        break;
    default:
        break;
    }
}
/**********************************************************
 * product4: P = A*B with A and B both sparse (compressed-column
 * arrays), restricted to numblk diagonal blocks: block l covers
 * rows and columns cumblk[l] .. cumblk[l+1]-1.  Only entries of
 * A*B inside these diagonal blocks are produced.
 *
 *   P, irP, jcP  output sparse arrays; P/irP need room for every
 *                nonzero produced, jcP for cumblk[numblk]+1 entries
 *   Ptmp         dense work vector of length >= cumblk[numblk];
 *                expected zero on entry, left zero on exit
 *   numblk       number of diagonal blocks
 *   cumblk       numblk+1 cumulative offsets (cumblk[0] expected 0)
 *
 * Contributions whose row lies outside the current block are
 * skipped: they belong to off-diagonal blocks that are not stored,
 * and accumulating them would leave stale values in Ptmp that are
 * never cleared (corrupting later columns) and could index past
 * the end of Ptmp.
 **********************************************************/
static void product4(double *A, int *irA, int *jcA,
                     double *B, int *irB, int *jcB,
                     double *P, int *irP, int *jcP,
                     double *Ptmp, int numblk, int *cumblk)
{
    int i, j, k, l, r, t, jstart, jend;
    int idx = 0;
    double tmp;

    jcP[0] = 0;
    for (l = 0; l < numblk; ++l) {
        jstart = cumblk[l]; jend = cumblk[l+1];
        for (j = jstart; j < jend; ++j) {
            /**** scatter column j of A*B, rows of block l only ****/
            for (k = jcB[j]; k < jcB[j+1]; ++k) {
                r = irB[k];
                tmp = B[k];
                for (i = jcA[r]; i < jcA[r+1]; ++i) {
                    t = irA[i];
                    if (t >= jstart && t < jend) { Ptmp[t] += tmp*A[i]; }
                }
            }
            /**** gather nonzeros in row order and reset Ptmp ****/
            for (k = jstart; k < jend; ++k) {
                tmp = Ptmp[k];
                if (tmp != 0) {
                    P[idx] = tmp; irP[idx] = k; idx++;
                    Ptmp[k] = 0;
                }
            }
            jcP[j+1] = idx;
        }
    }
    /* Same as the last jcP[j+1] when numblk > 0; also well defined
     * when numblk == 0 (no loop variable is read). */
    jcP[cumblk[numblk]] = idx;
}
/**********************************************************
 * product5: elementwise product P = A .* B of two real column
 * vectors of length n.  Each operand is dense (issp == 0) or
 * sparse (issp != 0, single-column compressed arrays ir/jc).
 * irA/jcA are read only when A is sparse, irB/jcB only when B
 * is sparse.  P is expected to be zero on entry: in the sparse
 * cases only rows where the operands have stored values are
 * written.
 **********************************************************/
static void product5(double *A, int *irA, int *jcA,
                     double *B, int *irB, int *jcB, double *P,
                     int n, int isspA, int isspB)
{
    int k, kx, ky, kx2, ky2, rx, ry;

    if (!isspA && !isspB) {
        for (k = 0; k < n; k++) { P[k] = A[k]*B[k]; }
    }
    else if (isspA && !isspB) {
        for (k = jcA[0]; k < jcA[1]; k++) {
            rx = irA[k];
            P[rx] = A[k]*B[rx];
        }
    }
    else if (!isspA && isspB) {
        for (k = jcB[0]; k < jcB[1]; k++) {
            ry = irB[k];
            P[ry] = A[ry]*B[k];
        }
    }
    else {
        /* Merge the two row-index lists (assumed ascending, as in MATLAB
         * sparse storage).  Row indices are read only while kx/ky are in
         * range, so irA[kx2] / irB[ky2] -- possibly past the end of the
         * arrays -- are never touched. */
        kx = jcA[0]; kx2 = jcA[1];
        ky = jcB[0]; ky2 = jcB[1];
        while (kx < kx2 && ky < ky2) {
            rx = irA[kx];
            ry = irB[ky];
            if (rx == ry) {
                P[rx] = A[kx]*B[ky];
                kx++; ky++;
            }
            else if (rx < ry) { kx++; }
            else { ky++; }
        }
    }
}
/**********************************************************/
|
|
void mexFunction(
|
|
int nlhs, mxArray *plhs[],
|
|
int nrhs, const mxArray *prhs[] )
|
|
{
|
|
mxArray *blk_cell_pr;
|
|
double *A, *B, *P, *blksize, *Ptmp;
|
|
int *irA, *jcA, *irB, *jcB, *irP, *jcP, *cumblk;
|
|
int isspA, isspB;
|
|
int m1, n1, m2, n2;
|
|
int type, index, numblk, NZmax, cols, i, l;
|
|
int subs[2];
|
|
int nsubs=2;
|
|
|
|
/* Check for proper number of arguments */
|
|
if (nrhs<3){
|
|
mexErrMsgTxt("mexProd2: requires at least 3 input arguments."); }
|
|
else if (nlhs>2){
|
|
mexErrMsgTxt("mexProd2: requires 1 output argument."); }
|
|
if (mxIsCell(prhs[1]) || mxIsCell(prhs[2])) {
|
|
mexErrMsgTxt("mexProd2: 2ND and 3RD input must both be matrices"); }
|
|
if (mxGetM(prhs[0]) > 1) {
|
|
mexErrMsgTxt("mexProd2: blk can only have 1 row"); }
|
|
|
|
/*** get pointers ***/
|
|
|
|
if (nrhs > 3) { type = (int)*mxGetPr(prhs[3]); }
|
|
else { type = 0; }
|
|
|
|
subs[0] = 0; subs[1] = 1;
|
|
index = mxCalcSingleSubscript(prhs[0],nsubs,subs);
|
|
blk_cell_pr = mxGetCell(prhs[0],index);
|
|
blksize = mxGetPr(blk_cell_pr);
|
|
numblk = mxGetN(blk_cell_pr);
|
|
cumblk = mxCalloc(numblk+1,sizeof(int));
|
|
NZmax = 0;
|
|
for (l=0; l<numblk; l++) {
|
|
cols = (int)blksize[l];
|
|
cumblk[l+1] = cumblk[l] + cols;
|
|
NZmax += cols*cols; }
|
|
|
|
A = mxGetPr(prhs[1]);
|
|
m1 = mxGetM(prhs[1]);
|
|
n1 = mxGetN(prhs[1]);
|
|
isspA = mxIsSparse(prhs[1]);
|
|
if (isspA) { irA = mxGetIr(prhs[1]);
|
|
jcA = mxGetJc(prhs[1]); }
|
|
|
|
B = mxGetPr(prhs[2]);
|
|
m2 = mxGetM(prhs[2]);
|
|
n2 = mxGetN(prhs[2]);
|
|
isspB = mxIsSparse(prhs[2]);
|
|
if (isspB) { irB = mxGetIr(prhs[2]);
|
|
jcB = mxGetJc(prhs[2]); }
|
|
|
|
if ((n1!=m2) & !(n1==1 & n2==1)) {
|
|
mexErrMsgTxt("mexProd2: 2ND and 3RD input not compatible"); }
|
|
if ((numblk > 1) & !(isspA & isspB) & !(n1==1 & n2==1)) {
|
|
mexErrMsgTxt("mexProd2: 2ND and 3RD must be both sparse"); }
|
|
|
|
/***** create return argument *****/
|
|
|
|
if (isspA & isspB & !(n1==1 & n2==1)){
|
|
plhs[0] = mxCreateSparse(m1,n2,NZmax,mxREAL);
|
|
P = mxGetPr(plhs[0]); irP = mxGetIr(plhs[0]); jcP = mxGetJc(plhs[0]); }
|
|
else {
|
|
plhs[0] = mxCreateDoubleMatrix(m1,n2,mxREAL);
|
|
P = mxGetPr(plhs[0]);
|
|
}
|
|
if (isspA & isspB & !(n1==1 & n2==1)) {
|
|
Ptmp = mxCalloc(cumblk[numblk],sizeof(double));
|
|
}
|
|
/**********************************************
|
|
* Do the actual computations in a subroutine
|
|
**********************************************/
|
|
|
|
if (m1 == m2 & n1 == 1 & n2 == 1) {
|
|
product5(A, irA, jcA, B, irB, jcB, P, m1, isspA, isspB);
|
|
} else {
|
|
if (!isspA & !isspB){
|
|
product(A, B, P, m1, n1, n2, type); }
|
|
else if (!isspA & isspB){
|
|
product2(A, B, irB, jcB, P, m1, n1, n2, type); }
|
|
else if (isspA & !isspB){
|
|
product3(A, irA, jcA, B, P, m1, n1, n2, type); }
|
|
else if (isspA & isspB){
|
|
product4(A, irA, jcA, B, irB, jcB,P,irP,jcP,Ptmp,numblk,cumblk);
|
|
}
|
|
}
|
|
mxFree(cumblk);
|
|
if (isspA & isspB & !(n1==1 & n2==1)) { mxFree(Ptmp); }
|
|
return;
|
|
}
|
|
/**********************************************************/
|
|
|