001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.ArrayList;
016import java.util.Collections;
017import java.util.Comparator;
018
019class InterpolatedPoint {
020
021        Dataset realPoint;
022        Dataset coordPoint;
023
024        public InterpolatedPoint(Dataset realPoint, Dataset coordPoint) {
025                this.realPoint = realPoint;
026                this.coordPoint = coordPoint;
027        }
028
029        public Dataset getRealPoint() {
030                return realPoint;
031        }
032
033        public Dataset getCoordPoint() {
034                return coordPoint;
035        }
036        
037        @Override
038        public String toString() {
039                String realString = "[ " + realPoint.getDouble(0);
040                for(int i = 1; i < realPoint.getShapeRef()[0]; i++) {
041                        realString += " , " + realPoint.getDouble(i);
042                }
043                realString += " ]";
044                
045                String coordString = "[ " + coordPoint.getDouble(0);
046                for(int i = 1; i < coordPoint.getShapeRef()[0]; i++) {
047                        coordString += " , " + coordPoint.getDouble(i) ;
048                }
049                coordString += " ]";
050                
051                return realString + " : " + coordString;
052        }
053
054}
055
056public class InterpolatorUtils {
057
058        public static Dataset regridOld(Dataset data, Dataset x, Dataset y,
059                        Dataset gridX, Dataset gridY) throws Exception {
060                
061                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, gridX.getShapeRef()[0], gridY.getShapeRef()[0]);
062                
063                IndexIterator itx = gridX.getIterator();
064                
065                // need a list of lists to store points
066                ArrayList<ArrayList<InterpolatedPoint>> pointList = new ArrayList<ArrayList<InterpolatedPoint>>();
067                
068                while(itx.hasNext()){
069                        // Add a list to contain all the points which we find
070                        pointList.add(new ArrayList<InterpolatedPoint>());
071                        
072                        int xindex = itx.index;
073                        double xPos = gridX.getDouble(xindex);
074                        
075                        IndexIterator ity = gridY.getIterator();
076                        while(ity.hasNext()){
077                                int yindex = ity.index;
078                                System.out.println("Testing : "+xindex+","+yindex);
079                                double yPos = gridX.getDouble(yindex);
080                                result.set(getInterpolated(data, x, y, xPos, yPos), yindex, xindex);
081                                
082                        }
083                }
084                return result;
085        }
086        
087        
088        
089        
090        public static Dataset selectDatasetRegion(Dataset dataset, int x, int y, int xSize, int ySize) {
091                int startX = x - xSize;
092                int startY = y - ySize;
093                int endX = x + xSize + 1;
094                int endY = y + ySize +1;
095                
096                int shapeX = dataset.getShapeRef()[0];
097                int shapeY = dataset.getShapeRef()[1];
098                
099                // Do edge checking
100                if (startX < 0) {
101                        startX = 0;
102                        endX = 3;
103                } 
104                
105                if (endX > shapeX) {
106                        endX = shapeX;
107                        startX = endX-3;
108                }
109                
110                if (startY < 0) {
111                        startY = 0;
112                        endY = 3;
113                }
114                
115                if (endY > shapeY) {
116                        endY = shapeY;
117                        startY = endY-3;
118                }
119                
120                int[] start = new int[] { startX, startY };
121                int[] stop = new int[] { endX, endY };
122                
123                
124                return dataset.getSlice(start, stop, null);
125        }
126        
127        private static double getInterpolated(Dataset val, Dataset x, Dataset y, double xPos,
128                        double yPos) throws Exception {
129                
130                // initial guess
131                Dataset xPosDS = x.getSlice(new int[] {0,0}, new int[] {x.getShapeRef()[0],1}, null).isubtract(xPos);
132                int xPosMin = xPosDS.minPos()[0];
133                Dataset yPosDS = y.getSlice(new int[] {xPosMin,0}, new int[] {xPosMin+1,y.getShapeRef()[1]}, null).isubtract(yPos);
134                int yPosMin = yPosDS.minPos()[0];
135                
136                
137                // now search around there 5x5
138                
139                Dataset xClipped = selectDatasetRegion(x,xPosMin,yPosMin,2,2);
140                Dataset yClipped = selectDatasetRegion(y,xPosMin,yPosMin,2,2);
141                
142                // first find the point in the arrays nearest to the point
143                Dataset xSquare = Maths.subtract(xClipped, xPos).ipower(2);
144                Dataset ySquare = Maths.subtract(yClipped, yPos).ipower(2);
145
146                Dataset total = Maths.add(xSquare, ySquare);
147
148                int[] pos = total.minPos();
149
150                // now pull out the region around that point, as a 3x3 grid     
151                Dataset xReduced = selectDatasetRegion(x, pos[0], pos[1], 1, 1);
152                Dataset yReduced = selectDatasetRegion(y, pos[0], pos[1], 1, 1);
153                Dataset valReduced = selectDatasetRegion(val, pos[0], pos[1], 1, 1);
154
155                return getInterpolatedResultFromNinePoints(valReduced, xReduced, yReduced, xPos, yPos);
156        }
157
158        private static double getInterpolatedResultFromNinePoints(Dataset val, Dataset x, Dataset y,
159                        double xPos, double yPos) throws Exception {
160                
161                // First build the nine points
162                InterpolatedPoint p00 = makePoint(x, y, 0, 0);
163                InterpolatedPoint p01 = makePoint(x, y, 0, 1);
164                InterpolatedPoint p02 = makePoint(x, y, 0, 2);
165                InterpolatedPoint p10 = makePoint(x, y, 1, 0);
166                InterpolatedPoint p11 = makePoint(x, y, 1, 1);
167                InterpolatedPoint p12 = makePoint(x, y, 1, 2);
168                InterpolatedPoint p20 = makePoint(x, y, 2, 0);
169                InterpolatedPoint p21 = makePoint(x, y, 2, 1);
170                InterpolatedPoint p22 = makePoint(x, y, 2, 2);
171
172                // now try every connection and find points that intersect with the interpolated value
173                ArrayList<InterpolatedPoint> points = new ArrayList<InterpolatedPoint>();
174
175                InterpolatedPoint A = get1DInterpolatedPoint(p00, p10, 0, xPos);
176                InterpolatedPoint B = get1DInterpolatedPoint(p10, p20, 0, xPos);
177                InterpolatedPoint C = get1DInterpolatedPoint(p00, p01, 0, xPos);
178                InterpolatedPoint D = get1DInterpolatedPoint(p10, p11, 0, xPos);
179                InterpolatedPoint E = get1DInterpolatedPoint(p20, p21, 0, xPos);
180                InterpolatedPoint F = get1DInterpolatedPoint(p01, p11, 0, xPos);
181                InterpolatedPoint G = get1DInterpolatedPoint(p11, p21, 0, xPos);
182                InterpolatedPoint H = get1DInterpolatedPoint(p01, p02, 0, xPos);
183                InterpolatedPoint I = get1DInterpolatedPoint(p11, p12, 0, xPos);
184                InterpolatedPoint J = get1DInterpolatedPoint(p21, p22, 0, xPos);
185                InterpolatedPoint K = get1DInterpolatedPoint(p02, p12, 0, xPos);
186                InterpolatedPoint L = get1DInterpolatedPoint(p12, p22, 0, xPos);
187
188                // Now add any to the list which are not null
189                if (A != null)
190                        points.add(A);
191                if (B != null)
192                        points.add(B);
193                if (C != null)
194                        points.add(C);
195                if (D != null)
196                        points.add(D);
197                if (E != null)
198                        points.add(E);
199                if (F != null)
200                        points.add(F);
201                if (G != null)
202                        points.add(G);
203                if (H != null)
204                        points.add(H);
205                if (I != null)
206                        points.add(I);
207                if (J != null)
208                        points.add(J);
209                if (K != null)
210                        points.add(K);
211                if (L != null)
212                        points.add(L);
213
214                // if no intercepts, then retun NaN;
215                if (points.size() == 0) return Double.NaN;
216                
217                InterpolatedPoint bestPoint = null;
218
219                // sort the points by y
220                Collections.sort(points, new Comparator<InterpolatedPoint>() {
221
222                        @Override
223                        public int compare(InterpolatedPoint o1, InterpolatedPoint o2) {
224                                return (int) Math.signum(o1.realPoint.getDouble(1) - o2.realPoint.getDouble(1));
225                        }
226                });
227                
228                
229                // now we have all the points which fit the x criteria, Find the points which fit the y
230                for (int a = 1; a < points.size(); a++) {
231                        InterpolatedPoint testPoint = get1DInterpolatedPoint(points.get(a - 1), points.get(a), 1, yPos);
232                        if (testPoint != null) {
233                                bestPoint = testPoint;
234                                break;
235                        }
236                }
237
238                if (bestPoint == null) {
239                        return Double.NaN;
240                }
241
242                // now we have the best point, we can calculate the weights, and positions
243                int xs = (int) Math.floor(bestPoint.getCoordPoint().getDouble(0));
244                int ys = (int) Math.floor(bestPoint.getCoordPoint().getDouble(1));
245                
246                double xoff = bestPoint.getCoordPoint().getDouble(0) - xs;
247                double yoff = bestPoint.getCoordPoint().getDouble(1) - ys;
248
249                // check corner cases
250                if (xs == 2) {
251                        xs = 1;
252                        xoff = 1.0;
253                }
254                
255                if (ys == 2) {
256                        ys = 1;
257                        yoff = 1.0;
258                }
259                
260                double w00 = (1 - xoff) * (1 - yoff);
261                double w10 = (xoff) * (1 - yoff);
262                double w01 = (1 - xoff) * (yoff);
263                double w11 = (xoff) * (yoff);
264                
265                // now using the weights, we can get the final interpolated value
266                double result = val.getDouble(xs, ys) * w00;
267                result += val.getDouble(xs + 1, ys) * w10;
268                result += val.getDouble(xs, ys + 1) * w01;
269                result += val.getDouble(xs + 1, ys + 1) * w11;
270                
271                return result;
272        }
273
274        private static InterpolatedPoint makePoint(Dataset x, Dataset y, int i, int j) {
275                Dataset realPoint = DatasetFactory.createFromObject(new double[] { x.getDouble(i, j), y.getDouble(i, j) });
276                Dataset coordPoint = DatasetFactory.createFromObject(new double[] { i, j });
277                return new InterpolatedPoint(realPoint, coordPoint);
278        }
279
280        /**
281         * Gets an interpolated position when only dealing with 1 dimension for the interpolation.
282         * 
283         * @param p1
284         *            Point 1
285         * @param p2
286         *            Point 2
287         * @param interpolationDimension
288         *            The dimension in which the interpolation should be carried out
289         * @param interpolatedValue
290         *            The value at which the interpolated point should be at in the chosen dimension
291         * @return the new interpolated point.
292         * @throws IllegalArgumentException
293         */
294        private static InterpolatedPoint get1DInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2,
295                        int interpolationDimension, double interpolatedValue) throws IllegalArgumentException {
296                
297                checkPoints(p1, p2);
298
299                if (interpolationDimension >= p1.getRealPoint().getShapeRef()[0]) {
300                        throw new IllegalArgumentException("Dimention is too large for these datasets");
301                }
302
303                double p1_n = p1.getRealPoint().getDouble(interpolationDimension);
304                double p2_n = p2.getRealPoint().getDouble(interpolationDimension);
305                double max = Math.max(p1_n, p2_n);
306                double min = Math.min(p1_n, p2_n);
307                
308                if (interpolatedValue < min || interpolatedValue > max || min==max) {
309                        return null;
310                }
311                
312                double proportion = (interpolatedValue - min) / (max - min);
313                
314                return getInterpolatedPoint(p1, p2, proportion);
315        }
316
317        /**
318         * Gets an interpolated point between 2 points given a certain proportion
319         * 
320         * @param p1
321         *            the initial point
322         * @param p2
323         *            the final point
324         * @param proportion
325         *            how far the new point is along the path between P1(0.0) and P2(1.0)
326         * @return a new point which is the interpolated point
327         */
328        private static InterpolatedPoint getInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2, double proportion) {
329
330                checkPoints(p1, p2);
331
332                if (proportion < 0 || proportion > 1.0) {
333                        throw new IllegalArgumentException("Proportion must be between 0 and 1");
334                }
335
336                Dataset p1RealContribution = Maths.multiply(p1.getRealPoint(), (1.0 - proportion));
337                Dataset p2RealContribution = Maths.multiply(p2.getRealPoint(), (proportion));
338
339                Dataset realPoint = Maths.add(p1RealContribution, p2RealContribution);
340
341                Dataset p1CoordContribution = Maths.multiply(p1.getCoordPoint(), (1.0 - proportion));
342                Dataset p2CoordContribution = Maths.multiply(p2.getCoordPoint(), (proportion));
343
344                Dataset coordPoint = Maths.add(p1CoordContribution, p2CoordContribution);
345
346                return new InterpolatedPoint(realPoint, coordPoint);
347        }
348
349        /**
350         * Checks to see if 2 points have the same dimensionality
351         * 
352         * @param p1
353         *            Point 1
354         * @param p2
355         *            Point 2
356         * @throws IllegalArgumentException
357         */
358        private static void checkPoints(InterpolatedPoint p1, InterpolatedPoint p2) throws IllegalArgumentException {
359                if (!p1.getCoordPoint().isCompatibleWith(p2.getCoordPoint())) {
360                        throw new IllegalArgumentException("Datasets do not match");
361                }
362        }
363        
364        public static Dataset remap1D(Dataset dataset, Dataset axis, Dataset outputAxis) {
365                Dataset data = DatasetFactory.zeros(DoubleDataset.class, outputAxis.getShapeRef());
366                for(int i = 0; i < outputAxis.getShapeRef()[0]; i++) {
367                        double point = outputAxis.getDouble(i);
368                        double position = getRealPositionAsIndex(axis, point);
369                        if (position >= 0.0) {
370                                data.set(Maths.interpolate(dataset, position), i);
371                        } else {
372                                data.set(Double.NaN,i);
373                        }
374                }
375                
376                return data;
377        }
378
379        // TODO need to make this work with reverse number lists
380        private static double getRealPositionAsIndex(Dataset dataset, double point) {
381                for (int j = 0; j < dataset.getShapeRef()[0]-1; j++) {
382                        double end = dataset.getDouble(j+1);
383                        double start = dataset.getDouble(j);
384                        //TODO could make this check once outside the loop with a minor assumption.
385                        if ( start < end) {
386                                if ((end > point) && (start <= point)) {
387                                        // we have a bounding point
388                                        double proportion = ((point-start)/(end-start));
389                                        return j + proportion;
390                                }
391                        } else {
392                                if ((end < point) && (start >= point)) {
393                                        // we have a bounding point
394                                        double proportion = ((point-start)/(end-start));
395                                        return j + proportion;
396                                }
397                        }
398                }
399                return -1.0;
400        }
401        
402        public static Dataset remapOneAxis(Dataset dataset, int axisIndex, Dataset corrections,
403                        Dataset originalAxisForCorrection, Dataset outputAxis) {
404                int[] stop = dataset.getShape();
405                int[] start = new int[stop.length];
406                int[] step = new int[stop.length];
407                int[] resultSize = new int[stop.length];
408                for (int i = 0 ; i < start.length; i++) {
409                        start[i] = 0;
410                        step[i] = 1;
411                        resultSize[i] = stop[i];
412                }
413                
414                resultSize[axisIndex] = outputAxis.getShapeRef()[0];
415                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);
416                
417                step[axisIndex] = dataset.getShapeRef()[axisIndex];
418                IndexIterator iter = dataset.getSliceIterator(start, stop, step);
419                
420                int[] pos = iter.getPos();
421                int[] posEnd = new int[pos.length];
422                while (iter.hasNext()){
423                        for (int i = 0 ; i < posEnd.length; i++) {
424                                posEnd[i] = pos[i]+1;
425                        }
426                        posEnd[axisIndex] = stop[axisIndex];
427                        // get the dataset
428                        Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();
429                        int[] correctionPos = new int[pos.length-1];
430                        int index = 0;
431                        for(int j = 0; j < pos.length; j++) {
432                                if (j != axisIndex) {
433                                        correctionPos[index] = pos[j];
434                                        index++;
435                                }
436                        }
437                        Dataset axis = Maths.subtract(originalAxisForCorrection,corrections.getDouble(correctionPos));
438                        Dataset remapped = remap1D(slice,axis,outputAxis);
439
440                        int[] ref = pos.clone();
441
442                        for (int k = 0; k < result.getShapeRef()[axisIndex]; k++) {
443                                ref[axisIndex] = k;
444                                result.set(remapped.getDouble(k), ref);
445                        }
446                }
447                
448                return result;
449        }
450        
451        
452        public static Dataset remapAxis(Dataset dataset, int axisIndex, Dataset originalAxisForCorrection, Dataset outputAxis) {
453                if (!dataset.isCompatibleWith(originalAxisForCorrection)) {
454                        throw new IllegalArgumentException("Datasets must be of the same shape");
455                }
456                
457                int[] stop = dataset.getShapeRef();
458                int[] start = new int[stop.length];
459                int[] step = new int[stop.length];
460                int[] resultSize = new int[stop.length];
461                for (int i = 0 ; i < start.length; i++) {
462                        start[i] = 0;
463                        step[i] = 1;
464                        resultSize[i] = stop[i];
465                }
466                
467                resultSize[axisIndex] = outputAxis.getShapeRef()[0];
468                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);
469                
470                step[axisIndex] = dataset.getShapeRef()[axisIndex];
471                IndexIterator iter = dataset.getSliceIterator(start, stop, step);
472                
473                int[] pos = iter.getPos();
474                int[] posEnd = new int[pos.length];
475                while (iter.hasNext()){
476                        for (int i = 0 ; i < posEnd.length; i++) {
477                                posEnd[i] = pos[i]+1;
478                        }
479                        posEnd[axisIndex] = stop[axisIndex];
480                        
481                        // get the dataset
482                        Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();
483                        Dataset axis = originalAxisForCorrection.getSlice(pos, posEnd, null).squeeze();
484                        
485                        Dataset remapped = remap1D(slice,axis,outputAxis);
486
487                        int[] ref = pos.clone();
488
489                        for (int k = 0; k < result.shape[axisIndex]; k++) {
490                                ref[axisIndex] = k;
491                                result.set(remapped.getDouble(k), ref);
492                        }
493                }
494                
495                return result;
496        }
497
498        public static Dataset regrid(Dataset data, Dataset x, Dataset y, Dataset gridX, Dataset gridY) {
499                
500                // apply X then Y regridding
501                Dataset result = remapAxis(data,1,x,gridX);
502                result = remapAxis(result,0,y,gridY);
503                
504                return result;
505        }
506}