diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java index 9e79b024a2f..e03ebd8cac7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java @@ -24,8 +24,9 @@ import java.util.Optional; import java.util.concurrent.Future; import java.util.stream.IntStream; +import java.util.regex.Matcher; +import java.util.regex.Pattern; -import org.apache.commons.lang3.tuple.ImmutableTriple; import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; @@ -103,45 +104,71 @@ else if(moAligned) private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixLineagePair moLin3) { FederatedRequest fr1; - if(moLin3 == null) + if(moLin3 == null) { fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}); - else + } + else { fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2, input3}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), moLin3.getFedMapping().getID()}); - + } + FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID()); FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID()); Future[] covTmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3); //means - Future[] meanTmp1 = processMean(mo1, 0); - Future[] meanTmp2 = processMean(mo2, 1); + Future[] meanTmp1 = processMean(mo1, moLin3, 0); + Future[] meanTmp2 = processMean(mo2, moLin3, 1); - ImmutableTriple res = getResponses(covTmp, meanTmp1, meanTmp2); + Double[] cov = getResponses(covTmp); + Double[] mean1 = 
getResponses(meanTmp1); + Double[] mean2 = getResponses(meanTmp2); - double result = aggCov(res.left, res.middle, res.right, mo1.getFedMapping().getFederatedRanges()); - ec.setVariable(output.getName(), new DoubleObject(result)); + if (moLin3 == null) { + double result = aggCov(cov, mean1, mean2, mo1.getFedMapping().getFederatedRanges()); + ec.setVariable(output.getName(), new DoubleObject(result)); + } + else { + Future[] weightsSumTmp = getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString, moLin3.getFedMapping()); + Double[] weights = getResponses(weightsSumTmp); + + double result = aggWeightedCov(cov, mean1, mean2, weights); + ec.setVariable(output.getName(), new DoubleObject(result)); + } } private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixLineagePair moLin3) { + + FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(moLin3, false); + + // the original instruction encodes weights as "pREADW", change to the new ID + String[] parts = instString.split("°"); + String covInstr = instString.replace(parts[4], String.valueOf(fr1[0].getID()) + "·MATRIX·FP64"); - FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(moLin3, false); - FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}); - FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID()); - FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID()); - Future[] covTmp = mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4); + FederatedRequest fr2 = FederationUtils.callInstruction( + covInstr, output, + new CPOperand[]{input1, input2, input3}, + new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr1[0].getID()} + ); + FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID()); + FederatedRequest fr4 = 
mo1.getFedMapping().cleanup(getTID(), fr2.getID()); + Future[] covTmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4); //means - Future[] meanTmp1 = processMean(mo1, 0); - Future[] meanTmp2 = processMean(mo2, 1); + Future[] meanTmp1 = processMean(mo1, 0, fr1[0].getID()); + Future[] meanTmp2 = processMean(mo2, 1, fr1[0].getID()); - ImmutableTriple res = getResponses(covTmp, meanTmp1, meanTmp2); + Double[] cov = getResponses(covTmp); + Double[] mean1 = getResponses(meanTmp1); + Double[] mean2 = getResponses(meanTmp2); - double result = aggCov(res.left, res.middle, res.right, mo1.getFedMapping().getFederatedRanges()); + Future[] weightsSumTmp = getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping()); + Double[] weights = getResponses(weightsSumTmp); + + double result = aggWeightedCov(cov, mean1, mean2, weights); ec.setVariable(output.getName(), new DoubleObject(result)); } @@ -174,11 +201,17 @@ private void processCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) } // with weights else { - MatrixBlock wtBlock = ec.getMatrixInput(input2.getName()); - response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new CovarianceFEDInstruction.COVWeightsFunction(data.getVarID(), - mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1), - cop, wtBlock))).get(); + MatrixBlock wtBlock = ec.getMatrixInput(input3.getName()); + response = data.executeFederatedOperation( + new FederatedRequest( + FederatedRequest.RequestType.EXEC_UDF, -1, + new CovarianceFEDInstruction.COVWeightsFunction( + data.getVarID(), + mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1), + cop, wtBlock.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1) + ) + ) + ).get(); } if(!response.isSuccessful()) @@ -202,59 +235,345 @@ private void processCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) } } - private static ImmutableTriple getResponses(Future[] covFfr, Future[] 
mean1Ffr, Future[] mean2Ffr) { - Double[] cov = new Double[covFfr.length]; - Double[] mean1 = new Double[mean1Ffr.length]; - Double[] mean2 = new Double[mean2Ffr.length]; - IntStream.range(0, covFfr.length).forEach(i -> { + private static Double[] getResponses(Future[] ffr) { + Double[] fr = new Double[ffr.length]; + IntStream.range(0, fr.length).forEach(i -> { try { - cov[i] = ((ScalarObject) covFfr[i].get().getData()[0]).getDoubleValue(); - mean1[i] = ((ScalarObject) mean1Ffr[1].get().getData()[0]).getDoubleValue(); - mean2[i] = ((ScalarObject) mean2Ffr[2].get().getData()[0]).getDoubleValue(); + fr[i] = ((ScalarObject) ffr[i].get().getData()[0]).getDoubleValue(); } catch(Exception e) { throw new DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov."); } }); - return new ImmutableTriple<>(cov, mean1, mean2); + return fr; } private static double aggCov(Double[] covValues, Double[] mean1, Double[] mean2, FederatedRange[] ranges) { - double cov = covValues[0]; - long size1 = ranges[0].getSize(); - double mean = (mean1[0] + mean2[0]) / 2; - - for(int i = 0; i < covValues.length - 1; i++) { - long size2 = ranges[i+1].getSize(); - double nextMean = (mean1[i+1] + mean2[i+1]) / 2; - double newMean = (size1 * mean + size2 * nextMean) / (size1 + size2); + long[] sizes = new long[ranges.length]; + for (int i = 0; i < ranges.length; i++) { + sizes[i] = ranges[i].getSize(); + } + + // calculate global means, weighting each partition mean by its range size + double totalMeanX = 0; + double totalMeanY = 0; + long totalCount = 0; // long: sizes[i] is long and += into an int would silently narrow + for (int i = 0; i < mean1.length; i++) { + totalMeanX += mean1[i] * sizes[i]; + totalMeanY += mean2[i] * sizes[i]; + totalCount += sizes[i]; + } + + totalMeanX /= totalCount; + totalMeanY /= totalCount; + + // calculate global covariance: pool (size_i - 1) * cov_i and add the between-partition mean-deviation term + double cov = 0; + for (int i = 0; i < covValues.length; i++) { + cov += (sizes[i] - 1) * covValues[i]; + cov += sizes[i] * (mean1[i] - totalMeanX) * (mean2[i] - totalMeanY); + } + return cov / (totalCount - 1); // adjusting for degrees of freedom + 
} - cov = (size1 * cov + size2 * covValues[i+1] + size1 * (mean - newMean) * (mean - newMean) - + size2 * (nextMean - newMean) * (nextMean - newMean)) / (size1 + size2); + private static double aggWeightedCov(Double[] covValues, Double[] mean1, Double[] mean2, Double[] weights) { + // calculate global weighted means, weighting each partition mean by weights[i] + double totalWeightedMeanX = 0; + double totalWeightedMeanY = 0; + double totalWeight = 0; + for (int i = 0; i < mean1.length; i++) { + totalWeight += weights[i]; + totalWeightedMeanX += mean1[i] * weights[i]; + totalWeightedMeanY += mean2[i] * weights[i]; + } + + totalWeightedMeanX /= totalWeight; + totalWeightedMeanY /= totalWeight; + + // calculate global weighted covariance: same pooling as aggCov with weights[i] in place of the range size + double cov = 0; + for (int i = 0; i < covValues.length; i++) { + cov += (weights[i] - 1) * covValues[i]; + cov += weights[i] * (mean1[i] - totalWeightedMeanX) * (mean2[i] - totalWeightedMeanY); + } + return cov / (totalWeight - 1); // adjusting for degrees of freedom + } - mean = newMean; - size1 = size1 + size2; + private Future[] processMean(MatrixObject mo1, MatrixLineagePair moLin3, int var){ + String[] parts = instString.split("°"); + Future[] meanTmp = null; + if (moLin3 == null) { + String meanInstr = instString.replace(getOpcode(), getOpcode().replace("cov", "uamean")); + meanInstr = meanInstr.replace((var == 0 ? parts[2] : parts[3]) + "°", ""); + meanInstr = meanInstr.replace(parts[4], parts[4].replace("FP64", "STRING°16")); + + //create federated commands for aggregation + FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output, + new CPOperand[]{var == 0 ? 
input2 : input1}, new long[]{mo1.getFedMapping().getID()}); + FederatedRequest meanFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID()); + FederatedRequest meanFr3 = mo1.getFedMapping().cleanup(getTID(), meanFr1.getID()); + + meanTmp = mo1.getFedMapping().execute(getTID(), meanFr1, meanFr2, meanFr3); } - return cov; + else { + // multiply input X by weights W element-wise + String multOutput = incrementVar(parts[4], 1); + String multInstr = instString + .replace(getOpcode(), getOpcode().replace("cov", "*")) + .replace((var == 0 ? parts[2] : parts[3]) + "°", "") + .replace(parts[5], multOutput); + + CPOperand multOutputCPOp = new CPOperand( + multOutput.substring(0, multOutput.indexOf("·")), + mo1.getValueType(), + mo1.getDataType() + ); + + FederatedRequest multFr = FederationUtils.callInstruction( + multInstr, + multOutputCPOp, + new CPOperand[]{var == 0 ? input2 : input1, input3}, + new long[]{mo1.getFedMapping().getID(), moLin3.getFedMapping().getID()} + ); + + // calculate the sum of the obtained vector + String[] partsMult = multInstr.split("°"); + String sumInstr1Output = incrementVar(multOutput, 1) + .replace("m", "") + .replace("MATRIX", "SCALAR"); + String sumInstr1 = multInstr + .replace(partsMult[1], "uak+") + .replace(partsMult[3] + "°", "") + .replace(partsMult[4], sumInstr1Output) + .replace(partsMult[2], multOutput); + + FederatedRequest sumFr1 = FederationUtils.callInstruction( + sumInstr1, + new CPOperand( + sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{multOutputCPOp}, + new long[]{multFr.getID()} + ); + + // calculate the sum of weights + String[] partsSum1 = sumInstr1.split("°"); + String sumInstr2Output = incrementVar(sumInstr1Output, 1); + String sumInstr2 = sumInstr1 + .replace(partsSum1[2], parts[4]) + .replace(partsSum1[3], sumInstr2Output); + + FederatedRequest sumFr2 = FederationUtils.callInstruction( + sumInstr2, + new 
CPOperand( + sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{input3}, + new long[]{moLin3.getFedMapping().getID()} + ); + + // divide sum(X*W) by sum(W) + String[] partsSum2 = sumInstr2.split("°"); + String divInstrOutput = incrementVar(sumInstr2Output, 1); + String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false"); + String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false"); + String divInstr = partsSum2[0] + "°" + partsSum2[1].replace("uak+", "/") + "°" + + divInstrInput1 + "°" + divInstrInput2 + "°" + divInstrOutput + "°" + partsSum2[4]; + + FederatedRequest divFr1 = FederationUtils.callInstruction( + divInstr, + new CPOperand( + divInstrOutput.substring(0, divInstrOutput.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{ + new CPOperand( + sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")), + output.getValueType(), + output.getDataType(), + output.isLiteral() + ), + new CPOperand( + sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")), + output.getValueType(), + output.getDataType(), + output.isLiteral() + ) + }, + new long[]{sumFr1.getID(), sumFr2.getID()} + ); + FederatedRequest divFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID()); + FederatedRequest divFr3 = mo1.getFedMapping().cleanup(getTID(), multFr.getID(), sumFr1.getID(), sumFr2.getID(), divFr1.getID(), divFr2.getID()); + + meanTmp = mo1.getFedMapping().execute(getTID(), multFr, sumFr1, sumFr2, divFr1, divFr2, divFr3); + } + return meanTmp; } - private Future[] processMean(MatrixObject mo1, int var){ + private Future[] processMean(MatrixObject mo1, int var, long weightsID){ String[] parts = instString.split("°"); - String meanInstr = instString.replace(getOpcode(), getOpcode().replace("cov", "uamean")); - meanInstr = meanInstr.replace((var == 0 ? 
parts[2] : parts[3]) + "°", ""); - meanInstr = meanInstr.replace(parts[4], parts[4].replace("FP64", "STRING°16")); Future[] meanTmp = null; - //create federated commands for aggregation - FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output, - new CPOperand[]{var == 0 ? input2 : input1}, new long[]{mo1.getFedMapping().getID()}); - FederatedRequest meanFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID()); - FederatedRequest meanFr3 = mo1.getFedMapping().cleanup(getTID(), meanFr1.getID()); - meanTmp = mo1.getFedMapping().execute(getTID(), meanFr1, meanFr2, meanFr3); + // multiply input X by weights W element-wise + String multOutput = (var == 0 ? incrementVar(parts[2], 5) : incrementVar(parts[3], 3)); + String multInstr = instString + .replace(getOpcode(), getOpcode().replace("cov", "*")) + .replace((var == 0 ? parts[2] : parts[3]) + "°", "") + .replace(parts[4], String.valueOf(weightsID) + "·MATRIX·FP64") + .replace(parts[5], multOutput); + + CPOperand multOutputCPOp = new CPOperand( + multOutput.substring(0, multOutput.indexOf("·")), + mo1.getValueType(), + mo1.getDataType() + ); + + FederatedRequest multFr = FederationUtils.callInstruction( + multInstr, + multOutputCPOp, + new CPOperand[]{var == 0 ? 
input2 : input1, input3}, + new long[]{mo1.getFedMapping().getID(), weightsID} + ); + + // calculate the sum of the obtained vector + String[] partsMult = multInstr.split("°"); + String sumInstr1Output = incrementVar(multOutput, 1) + .replace("m", "") + .replace("MATRIX", "SCALAR"); + String sumInstr1 = multInstr + .replace(partsMult[1], "uak+") + .replace(partsMult[3] + "°", "") + .replace(partsMult[4], sumInstr1Output) + .replace(partsMult[2], multOutput); + + FederatedRequest sumFr1 = FederationUtils.callInstruction( + sumInstr1, + new CPOperand( + sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{multOutputCPOp}, + new long[]{multFr.getID()} + ); + + // calculate the sum of weights + String[] partsSum1 = sumInstr1.split("°"); + String sumInstr2Output = incrementVar(sumInstr1Output, 1); + String sumInstr2 = sumInstr1 + .replace(partsSum1[2], String.valueOf(weightsID) + "·MATRIX·FP64") + .replace(partsSum1[3], sumInstr2Output); + + FederatedRequest sumFr2 = FederationUtils.callInstruction( + sumInstr2, + new CPOperand( + sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{input3}, + new long[]{weightsID} + ); + + // divide sum(X*W) by sum(W) + String[] partsSum2 = sumInstr2.split("°"); + String divInstrOutput = incrementVar(sumInstr2Output, 1); + String divInstrInput1 = partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false"); + String divInstrInput2 = partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false"); + String divInstr = partsSum2[0] + "°" + partsSum2[1].replace("uak+", "/") + "°" + + divInstrInput1 + "°" + divInstrInput2 + "°" + divInstrOutput + "°" + partsSum2[4]; + + FederatedRequest divFr1 = FederationUtils.callInstruction( + divInstr, + new CPOperand( + divInstrOutput.substring(0, divInstrOutput.indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{ 
+ new CPOperand( + sumInstr1Output.substring(0, sumInstr1Output.indexOf("·")), + output.getValueType(), + output.getDataType(), + output.isLiteral() + ), + new CPOperand( + sumInstr2Output.substring(0, sumInstr2Output.indexOf("·")), + output.getValueType(), + output.getDataType(), + output.isLiteral() + ) + }, + new long[]{sumFr1.getID(), sumFr2.getID()} + ); + FederatedRequest divFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID()); + FederatedRequest divFr3 = mo1.getFedMapping().cleanup(getTID(), multFr.getID(), sumFr1.getID(), sumFr2.getID(), divFr1.getID(), divFr2.getID()); + + meanTmp = mo1.getFedMapping().execute(getTID(), multFr, sumFr1, sumFr2, divFr1, divFr2, divFr3); return meanTmp; } + private Future[] getWeightsSum(MatrixLineagePair moLin3, long weightsID, String instString, FederationMap fedMap) { + Future[] weightsSumTmp = null; + + String[] parts = instString.split("°"); + if (!instString.contains("pREADW")) { + String sumInstr = "CP°uak+°" + parts[4] + "°" + parts[5] + "°" + parts[6]; + + FederatedRequest sumFr = FederationUtils.callInstruction( + sumInstr, + new CPOperand( + parts[5].substring(0, parts[5].indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{input3}, + new long[]{weightsID} + ); + FederatedRequest sumFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID()); + FederatedRequest sumFr3 = moLin3.getFedMapping().cleanup(getTID(), sumFr.getID()); + + weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, sumFr3); + } + else { + String sumInstr = "CP°uak+°" + String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6]; + FederatedRequest sumFr = FederationUtils.callInstruction( + sumInstr, + new CPOperand( + parts[5].substring(0, parts[5].indexOf("·")), + output.getValueType(), + output.getDataType() + ), + new CPOperand[]{input3}, + new long[]{weightsID} + ); + FederatedRequest sumFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID()); + FederatedRequest sumFr3 = fedMap.cleanup(getTID(), sumFr.getID()); + + weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, sumFr3); + } + return weightsSumTmp; + } + + private static String incrementVar(String str, int inc) { + StringBuilder strOut = new StringBuilder(str); + Pattern pattern = Pattern.compile("\\d+"); + Matcher matcher = pattern.matcher(strOut); + if (matcher.find()) { + int num = Integer.parseInt(matcher.group()) + inc; + int start = matcher.start(); + int end = matcher.end(); + strOut.replace(start, end, String.valueOf(num)); + } + return strOut.toString(); + } + private static class COVFunction extends FederatedUDF { private static final long serialVersionUID = -501036588060113499L; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java index 136cdde7f96..24e9fc3c055 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java @@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase { private final static String TEST_NAME1 = "FederatedCovarianceTest"; private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest"; + private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest"; + private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest"; + private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest"; private final static String TEST_DIR = "functions/federated/"; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/"; @@ -64,19 +67,37 @@ public void setUp() { TestUtils.clearAssertionInformation(); 
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"})); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"})); } @Test public void testCovCP() { - runCovTest(ExecMode.SINGLE_NODE, false); + runCovarianceTest(ExecMode.SINGLE_NODE, false); } @Test public void testAlignedCovCP() { - runCovTest(ExecMode.SINGLE_NODE, true); + runCovarianceTest(ExecMode.SINGLE_NODE, true); } - private void runCovTest(ExecMode execMode, boolean alignedFedInput) { + @Test + public void testCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false); + } + + @Test + public void testAlignedCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false); + } + + @Test + public void testAllAlignedCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true); + } + + private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; ExecMode platformOld = rtplatform; @@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) { DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if(rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? 
TEST_NAME4 : TEST_NAME5); + getAndLoadTestConfiguration(TEST_NAME); + + String HOME = SCRIPT_DIR + TEST_DIR; + + int r = rows / 4; + int c = cols; + + fullDMLScriptName = ""; + + // Create 4 random 5x1 matrices + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + // Create a 20x1 weights matrix (own seed, so W does not reuse the random sequence of X1) + double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 23); + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + writeInputMatrixWithMTD("W", W, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols)); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + + Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); + Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); + Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); + Process t4 = startLocalFedWorker(port4); + + try { + if(!isAlive(t1, t2, t3, t4)) + throw new RuntimeException("Failed starting federated worker"); + + rtplatform = execMode; + if(rtplatform == ExecMode.SPARK) { + System.out.println(7); + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + if (alignedInput) { + // Create 4 random 5x1 matrices (seeds differ from X1..X4 so Y does not reuse their random sequences) + double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 13); + double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 17); + double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 18); + double[][] Y4 = getRandomMatrix(r, c, 
1, 5, 1, 19); + + writeInputMatrixWithMTD("Y1", Y1, false, mc); + writeInputMatrixWithMTD("Y2", Y2, false, mc); + writeInputMatrixWithMTD("Y3", Y3, false, mc); + writeInputMatrixWithMTD("Y4", Y4, false, mc); + + if (!alignedWeights) { + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y1"), + input("Y2"), + input("Y3"), + input("Y4"), + + input("W"), + expected("S") + }; + runTest(null); + + // Run the dml script with federated matrices + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")), + + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")), + + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")), + + "in_W1=" + input("W"), + "rows=" + rows, "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + else { + double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 23); // seeds distinct from those used for X and Y + double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 27); + double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 28); + double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 29); + + writeInputMatrixWithMTD("W1", W1, false, mc); + writeInputMatrixWithMTD("W2", W2, false, mc); + writeInputMatrixWithMTD("W3", W3, false, mc); + writeInputMatrixWithMTD("W4", W4, false, mc); + + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y1"), 
+ input("Y2"), + input("Y3"), + input("Y4"), + + input("W1"), + input("W2"), + input("W3"), + input("W4"), + + expected("S") + }; + runTest(null); + + // Run the dml script with federated matrices and weights + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + "in_W1=" + TestUtils.federatedAddress(port1, input("W1")), + + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")), + "in_W2=" + TestUtils.federatedAddress(port2, input("W2")), + + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")), + "in_W3=" + TestUtils.federatedAddress(port3, input("W3")), + + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")), + "in_W4=" + TestUtils.federatedAddress(port4, input("W4")), + + "rows=" + rows, "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + + } + else { + // Create a random 20x1 input matrix (seed differs from X1..X4) + double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 13); + writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols)); + + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y"), input("W"), expected("S") + }; + runTest(null); + + // Run the dml script with a federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + 
TestUtils.federatedAddress(port4, input("X4")), + + "in_W1=" + input("W"), + "Y=" + input("Y"), + + "rows=" + rows, + "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + + // compare via files + compareResults(1e-2); + Assert.assertTrue(heavyHittersContainsString("fed_cov")); + + } + finally { + TestUtils.shutdownThreads(t1, t2, t3, t4); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml new file mode 100644 index 00000000000..da9db2f4dea --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# 5x1 on 4 workers -> 20x1 +X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +# 5x1 on 4 workers -> 20x1 +Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +W = read($in_W1); # 20x1 + +s = cov(X, Y, W); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml new file mode 100644 index 00000000000..ee4062f7e69 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: inputs $1..$4 row-bound into one matrix
+Y = rbind(read($5), read($6), read($7), read($8)); # 20x1: inputs $5..$8 row-bound into one matrix
+W = read($9); # 20x1 weights
+
+s = cov(X, Y, W); # cov of X and Y with W as third argument (weights)
+write(s, $10); # store the result at the path given by $10
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
new file mode 100644
index 00000000000..22029de451d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: 4 workers, rows split into quarters of $rows/4, all $cols columns (5x1 each -> 20x1 for rows=20, cols=1)
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+# Y: 4 workers, same row quarters as X (5x1 each -> 20x1 for rows=20, cols=1)
+Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+# W: weights, 4 workers, same row quarters as X and Y (5x1 each -> 20x1 for rows=20, cols=1)
+W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+s = cov(X, Y, W); # cov of X and Y with W as third argument (weights)
+write(s, $out_S); # store the result at the path given by $out_S
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
new file mode 100644
index 00000000000..10c18f5a333
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: inputs $1..$4 row-bound into one matrix
+Y = rbind(read($5), read($6), read($7), read($8)); # 20x1: inputs $5..$8 row-bound into one matrix
+W = rbind(read($9), read($10), read($11), read($12)); # 20x1 weights: inputs $9..$12 row-bound into one matrix
+
+s = cov(X, Y, W); # cov of X and Y with W as third argument (weights)
+write(s, $13); # store the result at the path given by $13
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
new file mode 100644
index 00000000000..3ba2d5b15f8
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: 4 workers, rows split into quarters of $rows/4, all $cols columns (5x1 each -> 20x1 for rows=20, cols=1)
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+Y = read($Y); # 20x1, loaded via read() instead of federated()
+W = read($in_W1); # 20x1 weights, loaded via read() instead of federated()
+
+s = cov(X, Y, W); # cov of X and Y with W as third argument (weights)
+write(s, $out_S); # store the result at the path given by $out_S
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
new file mode 100644
index 00000000000..db1dc7c5265
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: inputs $1..$4 row-bound into one matrix
+Y = read($5); # 20x1
+W = read($6); # 20x1 weights
+
+s = cov(X, Y, W); # cov of X and Y with W as third argument (weights)
+write(s, $7); # store the result at the path given by $7
\ No newline at end of file