/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan.cuda;

import java.util.Locale;

import org.apache.commons.lang.StringUtils;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;

/**
 * Code template provider for unary operations in the {@code cuda} code
 * generation package.
 *
 * <p>For a given {@link CNodeUnary.UnaryType} this class returns a source-code
 * snippet containing placeholders such as {@code %TMP%}, {@code %IN1%},
 * {@code %IN1v%}, {@code %IN1i%}, {@code %POS1%} and {@code %LEN%}. The
 * placeholders are not substituted here; that is left to the caller.
 *
 * <p>Two template sets exist, selected by
 * {@link LibMatrixNative#isSinglePrecision()}: one using the {@code f}-suffixed
 * math functions ({@code expf}, {@code sinf}, ...) and one using the unsuffixed
 * functions ({@code exp}, {@code sin}, ...).
 */
public class Unary
extends CodeTemplate {
    /**
     * Returns the code template for the given unary operation.
     *
     * @param type   the unary operation to generate code for
     * @param sparse whether the sparse-input variant of the template is
     *               requested; only row aggregates, vector primitives and
     *               {@code LOOKUP_R} have a distinct sparse variant
     * @return the template string, terminated by a newline
     * @throws RuntimeException if {@code type} has no template in the selected
     *                          precision
     */
    @Override
    public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
        return LibMatrixNative.isSinglePrecision()
            ? getSinglePrecisionTemplate(type, sparse)
            : getDoublePrecisionTemplate(type, sparse);
    }

    /**
     * Derives the primitive-name suffix of a {@code ROW_*} aggregate from its
     * enum name: the {@code ROW_} prefix and the trailing character are
     * dropped, the remainder is lower-cased and then capitalized
     * (e.g. {@code ROW_SUMS} becomes {@code Sum}).
     */
    private static String getRowAggregateName(CNodeUnary.UnaryType type) {
        String enumName = type.name();
        // Locale.ROOT keeps the generated identifier stable; the default-locale
        // overload would turn "MIN" into "mın" under a Turkish locale.
        String stem = enumName.substring(4, enumName.length() - 1).toLowerCase(Locale.ROOT);
        return StringUtils.capitalize(stem);
    }

    /**
     * Templates used when single precision is active.
     *
     * NOTE(review): the row-aggregate and vector templates here emit
     * {@code LibSpoofPrimitives.*} calls and {@code T[]} declarations, unlike
     * their double-precision counterparts - confirm this is intended.
     */
    private static String getSinglePrecisionTemplate(CNodeUnary.UnaryType type, boolean sparse) {
        switch (type) {
            case ROW_SUMS:
            case ROW_SUMSQS:
            case ROW_MINS:
            case ROW_MAXS:
            case ROW_MEANS:
            case ROW_COUNTNNZS: {
                String vectName = getRowAggregateName(type);
                return sparse
                    ? "\tT %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1v%, %IN1i%, %POS1%, alen, len);\n"
                    : "\tT %TMP% = LibSpoofPrimitives.vect" + vectName + "(%IN1%, %POS1%, %LEN%);\n";
            }
            case VECT_EXP:
            case VECT_POW2:
            case VECT_MULT2:
            case VECT_SQRT:
            case VECT_LOG:
            case VECT_ABS:
            case VECT_ROUND:
            case VECT_CEIL:
            case VECT_FLOOR:
            case VECT_SIGN:
            case VECT_SIN:
            case VECT_COS:
            case VECT_TAN:
            case VECT_ASIN:
            case VECT_ACOS:
            case VECT_ATAN:
            case VECT_SINH:
            case VECT_COSH:
            case VECT_TANH:
            case VECT_CUMSUM:
            case VECT_CUMMIN:
            case VECT_CUMMAX:
            case VECT_SPROP:
            case VECT_SIGMOID: {
                String vectName = type.getVectorPrimitiveName();
                return sparse
                    ? "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n"
                    : "\tT[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %POS1%, %LEN%);\n";
            }
            case EXP:
                return "\tT %TMP% = expf(%IN1%);\n";
            case LOOKUP_R:
                return sparse
                    ? "\tT %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n"
                    : "\tT %TMP% = getValue(%IN1%, rix);\n";
            case LOOKUP_C:
                return "\tT %TMP% = getValue(%IN1%, n, 0, cix);\n";
            case LOOKUP_RC:
                return "\tT %TMP% = getValue(%IN1%, n, rix, cix);\n";
            case LOOKUP0:
                return "\tT %TMP% = %IN1%[0];\n";
            case POW2:
                return "\tT %TMP% = %IN1% * %IN1%;\n";
            case MULT2:
                return "\tT %TMP% = %IN1% + %IN1%;\n";
            case ABS:
                return "\tT %TMP% = fabsf(%IN1%);\n";
            case SIN:
                return "\tT %TMP% = sinf(%IN1%);\n";
            case COS:
                return "\tT %TMP% = cosf(%IN1%);\n";
            case TAN:
                return "\tT %TMP% = tanf(%IN1%);\n";
            case ASIN:
                return "\tT %TMP% = asinf(%IN1%);\n";
            case ACOS:
                return "\tT %TMP% = acosf(%IN1%);\n";
            case ATAN:
                return "\tT %TMP% = atanf(%IN1%);\n";
            case SINH:
                return "\tT %TMP% = sinhf(%IN1%);\n";
            case COSH:
                return "\tT %TMP% = coshf(%IN1%);\n";
            case TANH:
                return "\tT %TMP% = tanhf(%IN1%);\n";
            case SIGN:
                // NOTE(review): evaluates to 1 for +0 (sign bit clear) rather
                // than 0 - confirm this matches the intended sign() semantics.
                return "\tT %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
            case SQRT:
                return "\tT %TMP% = sqrtf(%IN1%);\n";
            case LOG:
                return "\tT %TMP% = logf(%IN1%);\n";
            case ROUND:
                return "\tT %TMP% = roundf(%IN1%);\n";
            case CEIL:
                return "\tT %TMP% = ceilf(%IN1%);\n";
            case FLOOR:
                return "\tT %TMP% = floorf(%IN1%);\n";
            case SPROP:
                return "\tT %TMP% = %IN1% * (1 - %IN1%);\n";
            case SIGMOID:
                return "\tT %TMP% = 1 / (1 + expf(-%IN1%));\n";
            case LOG_NZ:
                return "\tT %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
            default:
                // Report the offending type, not this template object.
                throw new RuntimeException("Invalid unary type: " + type);
        }
    }

    /**
     * Templates used when single precision is not active. Leading whitespace
     * of the individual templates is intentionally left exactly as it was.
     */
    private static String getDoublePrecisionTemplate(CNodeUnary.UnaryType type, boolean sparse) {
        switch (type) {
            case ROW_SUMS:
            case ROW_SUMSQS:
            case ROW_MINS:
            case ROW_MAXS:
            case ROW_MEANS:
            case ROW_COUNTNNZS: {
                String vectName = getRowAggregateName(type);
                return sparse
                    ? "\t\tT %TMP% = vect" + vectName + "(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n"
                    : "\t\tT %TMP% = vect" + vectName + "(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
            }
            case VECT_EXP:
            case VECT_POW2:
            case VECT_MULT2:
            case VECT_SQRT:
            case VECT_LOG:
            case VECT_ABS:
            case VECT_ROUND:
            case VECT_CEIL:
            case VECT_FLOOR:
            case VECT_SIGN:
            case VECT_SIN:
            case VECT_COS:
            case VECT_TAN:
            case VECT_ASIN:
            case VECT_ACOS:
            case VECT_ATAN:
            case VECT_SINH:
            case VECT_COSH:
            case VECT_TANH:
            case VECT_CUMSUM:
            case VECT_CUMMIN:
            case VECT_CUMMAX:
            case VECT_SPROP:
            case VECT_SIGMOID: {
                String vectName = type.getVectorPrimitiveName();
                return sparse
                    ? "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n"
                    : "\t\tVector<T>& %TMP% = vect" + vectName + "Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
            }
            case EXP:
                return "\tT %TMP% = exp(%IN1%);\n";
            case LOOKUP_R:
                return sparse
                    ? "\tT %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n"
                    : "\t\tT %TMP% = %IN1%.val(rix);\n";
            case LOOKUP_C:
                return "\tT %TMP% = getValue(%IN1%, n, 0, cix);\n";
            case LOOKUP_RC:
                return "\tT %TMP% = getValue(%IN1%, n, rix, cix);\n";
            case LOOKUP0:
                return "\tT %TMP% = %IN1%[0];\n";
            case POW2:
                return "\tT %TMP% = %IN1% * %IN1%;\n";
            case MULT2:
                return "\tT %TMP% = %IN1% + %IN1%;\n";
            case ABS:
                return "\tT %TMP% = fabs(%IN1%);\n";
            case SIN:
                return "\tT %TMP% = sin(%IN1%);\n";
            case COS:
                return "\tT %TMP% = cos(%IN1%);\n";
            case TAN:
                return "\tT %TMP% = tan(%IN1%);\n";
            case ASIN:
                return "\tT %TMP% = asin(%IN1%);\n";
            case ACOS:
                return "\tT %TMP% = acos(%IN1%);\n";
            case ATAN:
                return "\tT %TMP% = atan(%IN1%);\n";
            case SINH:
                return "\tT %TMP% = sinh(%IN1%);\n";
            case COSH:
                return "\tT %TMP% = cosh(%IN1%);\n";
            case TANH:
                return "\tT %TMP% = tanh(%IN1%);\n";
            case SIGN:
                // NOTE(review): evaluates to 1 for +0 (sign bit clear) rather
                // than 0 - confirm this matches the intended sign() semantics.
                return "\tT %TMP% = signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
            case SQRT:
                return "\tT %TMP% = sqrt(%IN1%);\n";
            case LOG:
                return "\t\tT %TMP% = log(%IN1%);\n";
            case ROUND:
                return "\tT %TMP% = round(%IN1%);\n";
            case CEIL:
                return "\tT %TMP% = ceil(%IN1%);\n";
            case FLOOR:
                return "\tT %TMP% = floor(%IN1%);\n";
            case SPROP:
                return "\tT %TMP% = %IN1% * (1 - %IN1%);\n";
            case SIGMOID:
                return "\tT %TMP% = 1 / (1 + exp(-%IN1%));\n";
            case LOG_NZ:
                return "\tT %TMP% = (%IN1%==0) ? 0 : log(%IN1%);\n";
            default:
                // Report the offending type, not this template object.
                throw new RuntimeException("Invalid unary type: " + type);
        }
    }
}

