//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass reduces the VL where possible at the MI level, before VSETVLI
// instructions are inserted.
//
// The purpose of this optimization is to make the VL argument, for instructions
// that have a VL argument, as small as possible. This is implemented by
// visiting each instruction in reverse order and, if it has a VL argument,
// checking whether that VL can be reduced.
//
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-vl-optimizer"
#define PASS_NAME "RISC-V VL Optimizer"

namespace {

/// Machine function pass that, before VSETVLI insertion, shrinks the VL
/// operand of RVV pseudo instructions where the users of their results allow
/// it.
class RISCVVLOptimizer : public MachineFunctionPass {
  // Function-level state used by the helper methods. Default to null so the
  // pointers are never read uninitialized.
  const MachineRegisterInfo *MRI = nullptr;
  const MachineDominatorTree *MDT = nullptr;

public:
  static char ID;

  RISCVVLOptimizer() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  /// The pass leaves the CFG intact and requires the machine dominator tree.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return PASS_NAME; }

private:
  /// NOTE(review): judging by the name, returns the VL that the instruction
  /// owning UserOp needs from the value UserOp reads, or std::nullopt when it
  /// cannot be determined; confirm against the definition.
  std::optional<MachineOperand>
  getMinimumVLForUser(const MachineOperand &UserOp) const;
  /// Returns the largest common VL MachineOperand that may be used to optimize
  /// MI. Returns std::nullopt if it failed to find a suitable VL.
  std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const;
  /// Attempts to reduce the VL of MI. NOTE(review): the result presumably
  /// reports whether MI was modified; confirm against the definition.
  bool tryReduceVL(MachineInstr &MI) const;
  /// Returns true if MI is eligible to have its VL reduced.
  bool isCandidate(const MachineInstr &MI) const;

  /// For a given instruction, records what elements of it are demanded by
  /// downstream users.
  DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
};

/// Represents the EMUL and EEW of a MachineOperand.
struct OperandInfo {
  // Represent as 1,2,4,8, ... and fractional indicator. This is because
  // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
  // For example, a mask operand can have an EMUL less than MF8.
  // std::nullopt means the EMUL is not known; only the EEW is tracked.
  std::optional<std::pair<unsigned, bool>> EMUL;

  // log2 of the element width in bits (0 means EEW=1, as used for masks).
  unsigned Log2EEW;

  OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
      : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}

  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
      : EMUL(EMUL), Log2EEW(Log2EEW) {}

  /// Construct with an unknown EMUL; only the EEW is recorded.
  OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}

  OperandInfo() = delete;

  /// True when both the EEW and the EMUL match (two unknown EMULs compare
  /// equal).
  static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
    return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL;
  }

  /// True when the EEWs match, ignoring EMUL.
  static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
    return A.Log2EEW == B.Log2EEW;
  }

  /// Print as e.g. "EMUL: mf2, EEW: 16" or "EMUL: unknown, EEW: 16" on a
  /// single line.
  void print(raw_ostream &OS) const {
    OS << "EMUL: ";
    if (EMUL) {
      OS << "m";
      if (EMUL->second)
        OS << "f";
      OS << EMUL->first;
    } else {
      // No newline here: the EEW follows on the same line.
      OS << "unknown";
    }
    OS << ", EEW: " << (1 << Log2EEW);
  }
};

} // end anonymous namespace

// Definition of the pass ID declared in RISCVVLOptimizer.
char RISCVVLOptimizer::ID = 0;
// Register the pass as DEBUG_TYPE with description PASS_NAME and record its
// dependency on MachineDominatorTreeWrapperPass. The two trailing `false`
// arguments are the CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)

/// Create a fresh instance of the RISC-V VL optimizer pass. The returned
/// object is heap-allocated and owned by the caller.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  auto *Pass = new RISCVVLOptimizer();
  return Pass;
}

/// Returns true when R names a vector register. A virtual register is
/// classified by the TSFlags of its register class; a physical register is
/// checked for membership in the VR register class.
static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
  if (!R.isPhysical()) {
    const TargetRegisterClass *VirtRC = MRI->getRegClass(R);
    return RISCVRI::isVRegClass(VirtRC->TSFlags);
  }
  // NOTE(review): physical registers are only checked against VR; confirm
  // whether grouped (LMUL>1) physical registers need handling by callers.
  return RISCV::VRRegClass.contains(R);
}

/// Debug printer for an OperandInfo; forwards to OperandInfo::print.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}

/// Debug printer for an optional OperandInfo; an empty optional prints as
/// "nullopt".
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS,
                               const std::optional<OperandInfo> &OI) {
  if (!OI)
    return OS << "nullopt";
  return OS << *OI;
}

/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
/// SEW are from the TSFlags of MI.
static std::pair<unsigned, bool>
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
  RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
  auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
  unsigned MILog2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  // Mask instructions will have 0 as the SEW operand. But the LMUL of these
  // instructions is calculated is as if the SEW operand was 3 (e8).
  if (MILog2SEW == 0)
    MILog2SEW = 3;

  unsigned MISEW = 1 << MILog2SEW;

  unsigned EEW = 1 << Log2EEW;
  // Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
  // to put fraction in simplest form.
  unsigned Num = EEW, Denom = MISEW;
  int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL)
                               : std::gcd(Num * MILMUL, Denom);
  Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
  Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
  return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
}

/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
/// SEW comes from TSFlags of MI.
static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
                                              const MachineInstr &MI,
                                              const MachineOperand &MO) {
  unsigned MILog2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  if (MO.getOperandNo() == 0)
    return MILog2SEW;

  unsigned MISEW = 1 << MILog2SEW;
  unsigned EEW = MISEW / Factor;
  unsigned Log2EEW = Log2_32(EEW);

  return Log2EEW;
}

/// Check whether MO is a mask operand of MI.
static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
                          const MachineRegisterInfo *MRI) {

  if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI))
    return false;

  const MCInstrDesc &Desc = MI.getDesc();
  return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
}

static std::optional<unsigned>
getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  const MachineInstr &MI = *MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  // MI has a SEW associated with it. The RVV specification defines
  // the EEW of each operand and definition in relation to MI.SEW.
  unsigned MILog2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());
  const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags);

  bool IsMODef = MO.getOperandNo() == 0 ||
                 (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());

  // All mask operands have EEW=1
  if (isMaskOperand(MI, MO, MRI))
    return 0;

  // switch against BaseInstr to reduce number of cases that need to be
  // considered.
  switch (RVV->BaseInstr) {

  // 6. Configuration-Setting Instructions
  // Configuration setting instructions do not read or write vector registers
  case RISCV::VSETIVLI:
  case RISCV::VSETVL:
  case RISCV::VSETVLI:
    llvm_unreachable("Configuration setting instructions do not read or write "
                     "vector registers");

  // Vector Loads and Stores
  // Vector Unit-Stride Instructions
  // Vector Strided Instructions
  /// Dest EEW encoded in the instruction
  case RISCV::VLM_V:
  case RISCV::VSM_V:
    return 0;
  case RISCV::VLE8_V:
  case RISCV::VSE8_V:
  case RISCV::VLSE8_V:
  case RISCV::VSSE8_V:
    return 3;
  case RISCV::VLE16_V:
  case RISCV::VSE16_V:
  case RISCV::VLSE16_V:
  case RISCV::VSSE16_V:
    return 4;
  case RISCV::VLE32_V:
  case RISCV::VSE32_V:
  case RISCV::VLSE32_V:
  case RISCV::VSSE32_V:
    return 5;
  case RISCV::VLE64_V:
  case RISCV::VSE64_V:
  case RISCV::VLSE64_V:
  case RISCV::VSSE64_V:
    return 6;

  // Vector Indexed Instructions
  // vs(o|u)xei<eew>.v
  // Dest/Data (operand 0) EEW=SEW.  Source EEW=<eew>.
  case RISCV::VLUXEI8_V:
  case RISCV::VLOXEI8_V:
  case RISCV::VSUXEI8_V:
  case RISCV::VSOXEI8_V: {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 3;
  }
  case RISCV::VLUXEI16_V:
  case RISCV::VLOXEI16_V:
  case RISCV::VSUXEI16_V:
  case RISCV::VSOXEI16_V: {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 4;
  }
  case RISCV::VLUXEI32_V:
  case RISCV::VLOXEI32_V:
  case RISCV::VSUXEI32_V:
  case RISCV::VSOXEI32_V: {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 5;
  }
  case RISCV::VLUXEI64_V:
  case RISCV::VLOXEI64_V:
  case RISCV::VSUXEI64_V:
  case RISCV::VSOXEI64_V: {
    if (MO.getOperandNo() == 0)
      return MILog2SEW;
    return 6;
  }

  // Vector Integer Arithmetic Instructions
  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Bitwise Logical Instructions
  // Vector Single-Width Shift Instructions
  // EEW=SEW.
  case RISCV::VAND_VI:
  case RISCV::VAND_VV:
  case RISCV::VAND_VX:
  case RISCV::VOR_VI:
  case RISCV::VOR_VV:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VI:
  case RISCV::VXOR_VV:
  case RISCV::VXOR_VX:
  case RISCV::VSLL_VI:
  case RISCV::VSLL_VV:
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VI:
  case RISCV::VSRL_VV:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VI:
  case RISCV::VSRA_VV:
  case RISCV::VSRA_VX:
  // Vector Integer Min/Max Instructions
  // EEW=SEW.
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  // Source and Dest EEW=SEW.
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  // EEW=SEW.
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  // EEW=SEW.
  case RISCV::VMACC_VV:
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VV:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VV:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VV:
  case RISCV::VNMSUB_VX:
  // Vector Integer Merge Instructions
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
  // before this switch.
  case RISCV::VMERGE_VIM:
  case RISCV::VMERGE_VVM:
  case RISCV::VMERGE_VXM:
  case RISCV::VADC_VIM:
  case RISCV::VADC_VVM:
  case RISCV::VADC_VXM:
  case RISCV::VSBC_VVM:
  case RISCV::VSBC_VXM:
  // Vector Integer Move Instructions
  // Vector Fixed-Point Arithmetic Instructions
  // Vector Single-Width Saturating Add and Subtract
  // Vector Single-Width Averaging Add and Subtract
  // EEW=SEW.
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_V:
  case RISCV::VMV_V_X:
  case RISCV::VSADDU_VI:
  case RISCV::VSADDU_VV:
  case RISCV::VSADDU_VX:
  case RISCV::VSADD_VI:
  case RISCV::VSADD_VV:
  case RISCV::VSADD_VX:
  case RISCV::VSSUBU_VV:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VV:
  case RISCV::VSSUB_VX:
  case RISCV::VAADDU_VV:
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VV:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VV:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VV:
  case RISCV::VASUB_VX:
  // Vector Single-Width Fractional Multiply with Rounding and Saturation
  // EEW=SEW. The instruction produces 2*SEW product internally but
  // saturates to fit into SEW bits.
  case RISCV::VSMUL_VV:
  case RISCV::VSMUL_VX:
  // Vector Single-Width Scaling Shift Instructions
  // EEW=SEW.
  case RISCV::VSSRL_VI:
  case RISCV::VSSRL_VV:
  case RISCV::VSSRL_VX:
  case RISCV::VSSRA_VI:
  case RISCV::VSSRA_VV:
  case RISCV::VSSRA_VX:
  // Vector Permutation Instructions
  // Integer Scalar Move Instructions
  // Floating-Point Scalar Move Instructions
  // EEW=SEW.
  case RISCV::VMV_X_S:
  case RISCV::VMV_S_X:
  case RISCV::VFMV_F_S:
  case RISCV::VFMV_S_F:
  // Vector Slide Instructions
  // EEW=SEW.
  case RISCV::VSLIDEUP_VI:
  case RISCV::VSLIDEUP_VX:
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1UP_VX:
  case RISCV::VFSLIDE1UP_VF:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:
  // Vector Register Gather Instructions
  // EEW=SEW. For mask operand, EEW=1.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  // Vector Compress Instruction
  // EEW=SEW.
  case RISCV::VCOMPRESS_VM:
  // Vector Element Index Instruction
  case RISCV::VID_V:
  // Vector Single-Width Floating-Point Add/Subtract Instructions
  case RISCV::VFADD_VF:
  case RISCV::VFADD_VV:
  case RISCV::VFSUB_VF:
  case RISCV::VFSUB_VV:
  case RISCV::VFRSUB_VF:
  // Vector Single-Width Floating-Point Multiply/Divide Instructions
  case RISCV::VFMUL_VF:
  case RISCV::VFMUL_VV:
  case RISCV::VFDIV_VF:
  case RISCV::VFDIV_VV:
  case RISCV::VFRDIV_VF:
  // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFMACC_VV:
  case RISCV::VFMACC_VF:
  case RISCV::VFNMACC_VV:
  case RISCV::VFNMACC_VF:
  case RISCV::VFMSAC_VV:
  case RISCV::VFMSAC_VF:
  case RISCV::VFNMSAC_VV:
  case RISCV::VFNMSAC_VF:
  case RISCV::VFMADD_VV:
  case RISCV::VFMADD_VF:
  case RISCV::VFNMADD_VV:
  case RISCV::VFNMADD_VF:
  case RISCV::VFMSUB_VV:
  case RISCV::VFMSUB_VF:
  case RISCV::VFNMSUB_VV:
  case RISCV::VFNMSUB_VF:
  // Vector Floating-Point Square-Root Instruction
  case RISCV::VFSQRT_V:
  // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
  case RISCV::VFRSQRT7_V:
  // Vector Floating-Point Reciprocal Estimate Instruction
  case RISCV::VFREC7_V:
  // Vector Floating-Point MIN/MAX Instructions
  case RISCV::VFMIN_VF:
  case RISCV::VFMIN_VV:
  case RISCV::VFMAX_VF:
  case RISCV::VFMAX_VV:
  // Vector Floating-Point Sign-Injection Instructions
  case RISCV::VFSGNJ_VF:
  case RISCV::VFSGNJ_VV:
  case RISCV::VFSGNJN_VV:
  case RISCV::VFSGNJN_VF:
  case RISCV::VFSGNJX_VF:
  case RISCV::VFSGNJX_VV:
  // Vector Floating-Point Classify Instruction
  case RISCV::VFCLASS_V:
  // Vector Floating-Point Move Instruction
  case RISCV::VFMV_V_F:
  // Single-Width Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFCVT_XU_F_V:
  case RISCV::VFCVT_X_F_V:
  case RISCV::VFCVT_RTZ_XU_F_V:
  case RISCV::VFCVT_RTZ_X_F_V:
  case RISCV::VFCVT_F_XU_V:
  case RISCV::VFCVT_F_X_V:
  // Vector Floating-Point Merge Instruction
  case RISCV::VFMERGE_VFM:
  // Vector count population in mask vcpop.m
  // vfirst find-first-set mask bit
  case RISCV::VCPOP_M:
  case RISCV::VFIRST_M:
    return MILog2SEW;

  // Vector Widening Integer Add/Subtract
  // Def uses EEW=2*SEW . Operands use EEW=SEW.
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  case RISCV::VWSLL_VI:
  case RISCV::VWSLL_VX:
  case RISCV::VWSLL_VV:
  // Vector Widening Integer Multiply Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX:
  // Vector Widening Integer Multiply-Add Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
  // is then added to the 2*SEW-bit Dest. These instructions never have a
  // passthru operand.
  case RISCV::VWMACCU_VV:
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VV:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VV:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX:
  // Vector Widening Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFWMACC_VF:
  case RISCV::VFWMACC_VV:
  case RISCV::VFWNMACC_VF:
  case RISCV::VFWNMACC_VV:
  case RISCV::VFWMSAC_VF:
  case RISCV::VFWMSAC_VV:
  case RISCV::VFWNMSAC_VF:
  case RISCV::VFWNMSAC_VV:
  case RISCV::VFWMACCBF16_VV:
  case RISCV::VFWMACCBF16_VF:
  // Vector Widening Floating-Point Add/Subtract Instructions
  // Dest EEW=2*SEW. Source EEW=SEW.
  case RISCV::VFWADD_VV:
  case RISCV::VFWADD_VF:
  case RISCV::VFWSUB_VV:
  case RISCV::VFWSUB_VF:
  // Vector Widening Floating-Point Multiply
  case RISCV::VFWMUL_VF:
  case RISCV::VFWMUL_VV:
  // Widening Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFWCVT_XU_F_V:
  case RISCV::VFWCVT_X_F_V:
  case RISCV::VFWCVT_RTZ_XU_F_V:
  case RISCV::VFWCVT_RTZ_X_F_V:
  case RISCV::VFWCVT_F_XU_V:
  case RISCV::VFWCVT_F_X_V:
  case RISCV::VFWCVT_F_F_V:
  case RISCV::VFWCVTBF16_F_F_V:
    return IsMODef ? MILog2SEW + 1 : MILog2SEW;

  // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX:
  // Vector Widening Floating-Point Add/Subtract Instructions
  case RISCV::VFWADD_WF:
  case RISCV::VFWADD_WV:
  case RISCV::VFWSUB_WF:
  case RISCV::VFWSUB_WV: {
    bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
                                          : MO.getOperandNo() == 1;
    bool TwoTimes = IsMODef || IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
    return getIntegerExtensionOperandEEW(2, MI, MO);
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
    return getIntegerExtensionOperandEEW(4, MI, MO);
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
    return getIntegerExtensionOperandEEW(8, MI, MO);

  // Vector Narrowing Integer Right Shift Instructions
  // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
  case RISCV::VNSRL_WX:
  case RISCV::VNSRL_WI:
  case RISCV::VNSRL_WV:
  case RISCV::VNSRA_WI:
  case RISCV::VNSRA_WV:
  case RISCV::VNSRA_WX:
  // Vector Narrowing Fixed-Point Clip Instructions
  // Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
  case RISCV::VNCLIPU_WI:
  case RISCV::VNCLIPU_WV:
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIP_WI:
  case RISCV::VNCLIP_WV:
  case RISCV::VNCLIP_WX:
  // Narrowing Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFNCVT_XU_F_W:
  case RISCV::VFNCVT_X_F_W:
  case RISCV::VFNCVT_RTZ_XU_F_W:
  case RISCV::VFNCVT_RTZ_X_F_W:
  case RISCV::VFNCVT_F_XU_W:
  case RISCV::VFNCVT_F_X_W:
  case RISCV::VFNCVT_F_F_W:
  case RISCV::VFNCVT_ROD_F_F_W:
  case RISCV::VFNCVTBF16_F_F_W: {
    assert(!IsTied);
    bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
    bool TwoTimes = IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Mask Instructions
  // Vector Mask-Register Logical Instructions
  // vmsbf.m set-before-first mask bit
  // vmsif.m set-including-first mask bit
  // vmsof.m set-only-first mask bit
  // EEW=1
  // We handle the cases when operand is a v0 mask operand above the switch,
  // but these instructions may use non-v0 mask operands and need to be handled
  // specifically.
  case RISCV::VMAND_MM:
  case RISCV::VMNAND_MM:
  case RISCV::VMANDN_MM:
  case RISCV::VMXOR_MM:
  case RISCV::VMOR_MM:
  case RISCV::VMNOR_MM:
  case RISCV::VMORN_MM:
  case RISCV::VMXNOR_MM:
  case RISCV::VMSBF_M:
  case RISCV::VMSIF_M:
  case RISCV::VMSOF_M: {
    return MILog2SEW;
  }

  // Vector Iota Instruction
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
  // before this switch.
  case RISCV::VIOTA_M: {
    if (IsMODef || MO.getOperandNo() == 1)
      return MILog2SEW;
    return 0;
  }

  // Vector Integer Compare Instructions
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMSEQ_VI:
  case RISCV::VMSEQ_VV:
  case RISCV::VMSEQ_VX:
  case RISCV::VMSNE_VI:
  case RISCV::VMSNE_VV:
  case RISCV::VMSNE_VX:
  case RISCV::VMSLTU_VV:
  case RISCV::VMSLTU_VX:
  case RISCV::VMSLT_VV:
  case RISCV::VMSLT_VX:
  case RISCV::VMSLEU_VV:
  case RISCV::VMSLEU_VI:
  case RISCV::VMSLEU_VX:
  case RISCV::VMSLE_VV:
  case RISCV::VMSLE_VI:
  case RISCV::VMSLE_VX:
  case RISCV::VMSGTU_VI:
  case RISCV::VMSGTU_VX:
  case RISCV::VMSGT_VI:
  case RISCV::VMSGT_VX:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
  case RISCV::VMADC_VIM:
  case RISCV::VMADC_VVM:
  case RISCV::VMADC_VXM:
  case RISCV::VMSBC_VVM:
  case RISCV::VMSBC_VXM:
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMADC_VV:
  case RISCV::VMADC_VI:
  case RISCV::VMADC_VX:
  case RISCV::VMSBC_VV:
  case RISCV::VMSBC_VX:
  // 13.13. Vector Floating-Point Compare Instructions
  // Dest EEW=1. Source EEW=SEW
  case RISCV::VMFEQ_VF:
  case RISCV::VMFEQ_VV:
  case RISCV::VMFNE_VF:
  case RISCV::VMFNE_VV:
  case RISCV::VMFLT_VF:
  case RISCV::VMFLT_VV:
  case RISCV::VMFLE_VF:
  case RISCV::VMFLE_VV:
  case RISCV::VMFGT_VF:
  case RISCV::VMFGE_VF: {
    if (IsMODef)
      return 0;
    return MILog2SEW;
  }

  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  // Vector Single-Width Floating-Point Reduction Instructions
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS: {
    return MILog2SEW;
  }

  // Vector Widening Integer Reduction Instructions
  // The Dest and VS1 read only element 0 for the vector register. Return
  // 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  // Vector Widening Floating-Point Reduction Instructions
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS: {
    bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  default:
    return std::nullopt;
  }
}

/// Return the EEW and EMUL of operand MO, or std::nullopt when its EEW is
/// unknown (see getOperandLog2EEW).
///
/// For reductions, only VS2 (operand 2) is a full register group. Every other
/// vector operand reads or writes only element 0, so only its EEW is recorded
/// and its EMUL is left unknown. All other operands have EMUL=EEW/SEW*LMUL.
static std::optional<OperandInfo>
getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  const MachineInstr &MI = *MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
  if (!Log2EEW)
    return std::nullopt;

  switch (RVV->BaseInstr) {
  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  // Vector Single-Width Floating-Point Reduction Instructions
  // Vector Widening Integer Reduction Instructions
  // Vector Widening Floating-Point Reduction Instructions
  // The Dest and VS1 only use element 0 of the vector register. Return just
  // the EEW for these.
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    if (MO.getOperandNo() != 2)
      return OperandInfo(*Log2EEW);
    break;
  default:
    break;
  }

  // All others have EMUL=EEW/SEW*LMUL
  return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW);
}

/// Return true if this optimization should consider MI for VL reduction. This
/// white-list approach simplifies this optimization for instructions that may
/// have more complex semantics with relation to how it uses VL.
///
/// Loads are accepted only when none of their memory operands is volatile.
static bool isSupportedInstr(const MachineInstr &MI) {
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());

  if (!RVV)
    return false;

  switch (RVV->BaseInstr) {
  // Vector Unit-Stride Instructions
  // Vector Strided Instructions
  case RISCV::VLM_V:
  case RISCV::VLE8_V:
  case RISCV::VLSE8_V:
  case RISCV::VLE16_V:
  case RISCV::VLSE16_V:
  case RISCV::VLE32_V:
  case RISCV::VLSE32_V:
  case RISCV::VLE64_V:
  case RISCV::VLSE64_V:
  // Vector Indexed Instructions
  case RISCV::VLUXEI8_V:
  case RISCV::VLOXEI8_V:
  case RISCV::VLUXEI16_V:
  case RISCV::VLOXEI16_V:
  case RISCV::VLUXEI32_V:
  case RISCV::VLOXEI32_V:
  case RISCV::VLUXEI64_V:
  case RISCV::VLOXEI64_V: {
    // Reducing VL changes which memory is accessed, so volatile loads are
    // left untouched.
    for (const MachineMemOperand *MMO : MI.memoperands())
      if (MMO->isVolatile())
        return false;
    return true;
  }

  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Bitwise Logical Instructions
  // Vector Single-Width Shift Instructions
  case RISCV::VAND_VI:
  case RISCV::VAND_VV:
  case RISCV::VAND_VX:
  case RISCV::VOR_VI:
  case RISCV::VOR_VV:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VI:
  case RISCV::VXOR_VV:
  case RISCV::VXOR_VX:
  case RISCV::VSLL_VI:
  case RISCV::VSLL_VV:
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VI:
  case RISCV::VSRL_VV:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VI:
  case RISCV::VSRA_VV:
  case RISCV::VSRA_VX:
  // Vector Widening Integer Add/Subtract
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX:
  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // (forms without a carry-in/borrow-in).
  // FIXME: Add support for the carry-in/borrow-in forms VMADC_V[VXI]M and
  // VMSBC_V[VX]M, which are not listed here.
  case RISCV::VMADC_VV:
  case RISCV::VMADC_VI:
  case RISCV::VMADC_VX:
  case RISCV::VMSBC_VV:
  case RISCV::VMSBC_VX:
  // Vector Narrowing Integer Right Shift Instructions
  case RISCV::VNSRL_WX:
  case RISCV::VNSRL_WI:
  case RISCV::VNSRL_WV:
  case RISCV::VNSRA_WI:
  case RISCV::VNSRA_WV:
  case RISCV::VNSRA_WX:
  // Vector Integer Compare Instructions
  case RISCV::VMSEQ_VI:
  case RISCV::VMSEQ_VV:
  case RISCV::VMSEQ_VX:
  case RISCV::VMSNE_VI:
  case RISCV::VMSNE_VV:
  case RISCV::VMSNE_VX:
  case RISCV::VMSLTU_VV:
  case RISCV::VMSLTU_VX:
  case RISCV::VMSLT_VV:
  case RISCV::VMSLT_VX:
  case RISCV::VMSLEU_VV:
  case RISCV::VMSLEU_VI:
  case RISCV::VMSLEU_VX:
  case RISCV::VMSLE_VV:
  case RISCV::VMSLE_VI:
  case RISCV::VMSLE_VX:
  case RISCV::VMSGTU_VI:
  case RISCV::VMSGTU_VX:
  case RISCV::VMSGT_VI:
  case RISCV::VMSGT_VX:
  // Vector Integer Min/Max Instructions
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Widening Integer Multiply Instructions
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  case RISCV::VMACC_VV:
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VV:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VV:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VV:
  case RISCV::VNMSUB_VX:
  // Vector Integer Merge Instructions
  case RISCV::VMERGE_VIM:
  case RISCV::VMERGE_VVM:
  case RISCV::VMERGE_VXM:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  case RISCV::VADC_VIM:
  case RISCV::VADC_VVM:
  case RISCV::VADC_VXM:
  // Vector Widening Integer Multiply-Add Instructions
  case RISCV::VWMACCU_VV:
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VV:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VV:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX:
  // Vector Integer Move Instructions
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_X:
  case RISCV::VMV_V_V:
  // Vector Single-Width Saturating Add and Subtract
  case RISCV::VSADDU_VV:
  case RISCV::VSADDU_VX:
  case RISCV::VSADDU_VI:
  case RISCV::VSADD_VV:
  case RISCV::VSADD_VX:
  case RISCV::VSADD_VI:
  case RISCV::VSSUBU_VV:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VV:
  case RISCV::VSSUB_VX:
  // Vector Single-Width Averaging Add and Subtract
  case RISCV::VAADDU_VV:
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VV:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VV:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VV:
  case RISCV::VASUB_VX:
  // Vector Single-Width Fractional Multiply with Rounding and Saturation
  case RISCV::VSMUL_VV:
  case RISCV::VSMUL_VX:
  // Vector Single-Width Scaling Shift Instructions
  case RISCV::VSSRL_VV:
  case RISCV::VSSRL_VX:
  case RISCV::VSSRL_VI:
  case RISCV::VSSRA_VV:
  case RISCV::VSSRA_VX:
  case RISCV::VSSRA_VI:
  // Vector Narrowing Fixed-Point Clip Instructions
  case RISCV::VNCLIPU_WV:
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIPU_WI:
  case RISCV::VNCLIP_WV:
  case RISCV::VNCLIP_WX:
  case RISCV::VNCLIP_WI:

  // Vector Crypto
  case RISCV::VWSLL_VI:
  case RISCV::VWSLL_VX:
  case RISCV::VWSLL_VV:

  // Vector Mask Instructions
  // Vector Mask-Register Logical Instructions
  // vmsbf.m set-before-first mask bit
  // vmsif.m set-including-first mask bit
  // vmsof.m set-only-first mask bit
  // Vector Iota Instruction
  // Vector Element Index Instruction
  case RISCV::VMAND_MM:
  case RISCV::VMNAND_MM:
  case RISCV::VMANDN_MM:
  case RISCV::VMXOR_MM:
  case RISCV::VMOR_MM:
  case RISCV::VMNOR_MM:
  case RISCV::VMORN_MM:
  case RISCV::VMXNOR_MM:
  case RISCV::VMSBF_M:
  case RISCV::VMSIF_M:
  case RISCV::VMSOF_M:
  case RISCV::VIOTA_M:
  case RISCV::VID_V:
  // Vector Slide Instructions
  case RISCV::VSLIDEUP_VX:
  case RISCV::VSLIDEUP_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDE1UP_VX:
  case RISCV::VFSLIDE1UP_VF:
  // Vector Single-Width Floating-Point Add/Subtract Instructions
  case RISCV::VFADD_VF:
  case RISCV::VFADD_VV:
  case RISCV::VFSUB_VF:
  case RISCV::VFSUB_VV:
  case RISCV::VFRSUB_VF:
  // Vector Widening Floating-Point Add/Subtract Instructions
  case RISCV::VFWADD_VV:
  case RISCV::VFWADD_VF:
  case RISCV::VFWSUB_VV:
  case RISCV::VFWSUB_VF:
  case RISCV::VFWADD_WF:
  case RISCV::VFWADD_WV:
  case RISCV::VFWSUB_WF:
  case RISCV::VFWSUB_WV:
  // Vector Single-Width Floating-Point Multiply/Divide Instructions
  case RISCV::VFMUL_VF:
  case RISCV::VFMUL_VV:
  case RISCV::VFDIV_VF:
  case RISCV::VFDIV_VV:
  case RISCV::VFRDIV_VF:
  // Vector Widening Floating-Point Multiply
  case RISCV::VFWMUL_VF:
  case RISCV::VFWMUL_VV:
  // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFMACC_VV:
  case RISCV::VFMACC_VF:
  case RISCV::VFNMACC_VV:
  case RISCV::VFNMACC_VF:
  case RISCV::VFMSAC_VV:
  case RISCV::VFMSAC_VF:
  case RISCV::VFNMSAC_VV:
  case RISCV::VFNMSAC_VF:
  case RISCV::VFMADD_VV:
  case RISCV::VFMADD_VF:
  case RISCV::VFNMADD_VV:
  case RISCV::VFNMADD_VF:
  case RISCV::VFMSUB_VV:
  case RISCV::VFMSUB_VF:
  case RISCV::VFNMSUB_VV:
  case RISCV::VFNMSUB_VF:
  // Vector Widening Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFWMACC_VV:
  case RISCV::VFWMACC_VF:
  case RISCV::VFWNMACC_VV:
  case RISCV::VFWNMACC_VF:
  case RISCV::VFWMSAC_VV:
  case RISCV::VFWMSAC_VF:
  case RISCV::VFWNMSAC_VV:
  case RISCV::VFWNMSAC_VF:
  case RISCV::VFWMACCBF16_VV:
  case RISCV::VFWMACCBF16_VF:
  // Vector Floating-Point Square-Root Instruction
  case RISCV::VFSQRT_V:
  // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
  case RISCV::VFRSQRT7_V:
  // Vector Floating-Point Reciprocal Estimate Instruction
  case RISCV::VFREC7_V:
  // Vector Floating-Point MIN/MAX Instructions
  case RISCV::VFMIN_VF:
  case RISCV::VFMIN_VV:
  case RISCV::VFMAX_VF:
  case RISCV::VFMAX_VV:
  // Vector Floating-Point Sign-Injection Instructions
  case RISCV::VFSGNJ_VF:
  case RISCV::VFSGNJ_VV:
  case RISCV::VFSGNJN_VV:
  case RISCV::VFSGNJN_VF:
  case RISCV::VFSGNJX_VF:
  case RISCV::VFSGNJX_VV:
  // Vector Floating-Point Compare Instructions
  case RISCV::VMFEQ_VF:
  case RISCV::VMFEQ_VV:
  case RISCV::VMFNE_VF:
  case RISCV::VMFNE_VV:
  case RISCV::VMFLT_VF:
  case RISCV::VMFLT_VV:
  case RISCV::VMFLE_VF:
  case RISCV::VMFLE_VV:
  case RISCV::VMFGT_VF:
  case RISCV::VMFGE_VF:
  // Vector Floating-Point Merge Instruction
  case RISCV::VFMERGE_VFM:
  // Vector Floating-Point Move Instruction
  case RISCV::VFMV_V_F:
  // Single-Width Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFCVT_XU_F_V:
  case RISCV::VFCVT_X_F_V:
  case RISCV::VFCVT_RTZ_XU_F_V:
  case RISCV::VFCVT_RTZ_X_F_V:
  case RISCV::VFCVT_F_XU_V:
  case RISCV::VFCVT_F_X_V:
  // Widening Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFWCVT_XU_F_V:
  case RISCV::VFWCVT_X_F_V:
  case RISCV::VFWCVT_RTZ_XU_F_V:
  case RISCV::VFWCVT_RTZ_X_F_V:
  case RISCV::VFWCVT_F_XU_V:
  case RISCV::VFWCVT_F_X_V:
  case RISCV::VFWCVT_F_F_V:
  case RISCV::VFWCVTBF16_F_F_V:
  // Narrowing Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFNCVT_XU_F_W:
  case RISCV::VFNCVT_X_F_W:
  case RISCV::VFNCVT_RTZ_XU_F_W:
  case RISCV::VFNCVT_RTZ_X_F_W:
  case RISCV::VFNCVT_F_XU_W:
  case RISCV::VFNCVT_F_X_W:
  case RISCV::VFNCVT_F_F_W:
  case RISCV::VFNCVT_ROD_F_F_W:
  case RISCV::VFNCVTBF16_F_F_W:
    return true;

  // Anything not explicitly listed is not considered for VL reduction.
  default:
    return false;
  }
}

/// Return true if MO is a vector operand but is used as a scalar operand.
static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
  const MachineInstr *MI = MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());

  if (!RVV)
    return false;

  switch (RVV->BaseInstr) {
  // Reductions only use vs1[0] of vs1
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    return MO.getOperandNo() == 3;
  case RISCV::VMV_X_S:
  case RISCV::VFMV_F_S:
    return MO.getOperandNo() == 1;
  default:
    return false;
  }
}

/// Return true if MI may read elements past VL.
static bool mayReadPastVL(const MachineInstr &MI) {
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  if (!RVV)
    return true;

  switch (RVV->BaseInstr) {
  // vslidedown instructions may read elements past VL. They are handled
  // according to current tail policy.
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:

  // vrgather instructions may read the source vector at any index < VLMAX,
  // regardless of VL.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  case RISCV::VRGATHEREI16_VV:
    return true;

  default:
    return false;
  }
}

bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
  const MCInstrDesc &Desc = MI.getDesc();
  if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
    return false;

  if (MI.getNumExplicitDefs() != 1)
    return false;

  // Some instructions have implicit defs e.g. $vxsat. If they might be read
  // later then we can't reduce VL.
  if (!MI.allImplicitDefsAreDead()) {
    LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
    return false;
  }

  if (MI.mayRaiseFPException()) {
    LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
    return false;
  }

  // Some instructions that produce vectors have semantics that make it more
  // difficult to determine whether the VL can be reduced. For example, some
  // instructions, such as reductions, may write lanes past VL to a scalar
  // register. Other instructions, such as some loads or stores, may write
  // lower lanes using data from higher lanes. There may be other complex
  // semantics not mentioned here that make it hard to determine whether
  // the VL can be optimized. As a result, a white-list of supported
  // instructions is used. Over time, more instructions can be supported
  // upon careful examination of their semantics under the logic in this
  // optimization.
  // TODO: Use a better approach than a white-list, such as adding
  // properties to instructions using something like TSFlags.
  if (!isSupportedInstr(MI)) {
    LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
    return false;
  }

  assert(!RISCVII::elementsDependOnVL(RISCV::getRVVMCOpcode(MI.getOpcode())) &&
         "Instruction shouldn't be supported if elements depend on VL");

  assert(MI.getOperand(0).isReg() &&
         isVectorRegClass(MI.getOperand(0).getReg(), MRI) &&
         "All supported instructions produce a vector register result");

  LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
  return true;
}

std::optional<MachineOperand>
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
  const MachineInstr &UserMI = *UserOp.getParent();
  const MCInstrDesc &Desc = UserMI.getDesc();

  if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
    LLVM_DEBUG(dbgs() << "    Abort due to lack of VL, assume that"
                         " use VLMAX\n");
    return std::nullopt;
  }

  if (mayReadPastVL(UserMI)) {
    LLVM_DEBUG(dbgs() << "    Abort because used by unsafe instruction\n");
    return std::nullopt;
  }

  unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
  const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
  // Looking for an immediate or a register VL that isn't X0.
  assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
         "Did not expect X0 VL");

  // If the user is a passthru it will read the elements past VL, so
  // abort if any of the elements past VL are demanded.
  if (UserOp.isTied()) {
    assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
           RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
    auto DemandedVL = DemandedVLs.lookup(&UserMI);
    if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
      LLVM_DEBUG(dbgs() << "    Abort because user is passthru in "
                           "instruction with demanded tail\n");
      return std::nullopt;
    }
  }

  // Instructions like reductions may use a vector register as a scalar
  // register. In this case, we should treat it as only reading the first lane.
  if (isVectorOpUsedAsScalarOp(UserOp)) {
    LLVM_DEBUG(dbgs() << "    Used this operand as a scalar operand\n");
    return MachineOperand::CreateImm(1);
  }

  // If we know the demanded VL of UserMI, then we can reduce the VL it
  // requires.
  if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) {
    assert(isCandidate(UserMI));
    if (RISCV::isVLKnownLE(*DemandedVL, VLOp))
      return DemandedVL;
  }

  return VLOp;
}

std::optional<MachineOperand>
RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
  std::optional<MachineOperand> CommonVL;
  SmallSetVector<MachineOperand *, 8> Worklist;
  SmallPtrSet<const MachineInstr *, 4> PHISeen;
  for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg()))
    Worklist.insert(&UserOp);

  while (!Worklist.empty()) {
    MachineOperand &UserOp = *Worklist.pop_back_val();
    const MachineInstr &UserMI = *UserOp.getParent();
    LLVM_DEBUG(dbgs() << "  Checking user: " << UserMI << "\n");

    if (UserMI.isFullCopy() && UserMI.getOperand(0).getReg().isVirtual()) {
      LLVM_DEBUG(dbgs() << "    Peeking through uses of COPY\n");
      Worklist.insert_range(llvm::make_pointer_range(
          MRI->use_operands(UserMI.getOperand(0).getReg())));
      continue;
    }

    if (UserMI.isPHI()) {
      // Don't follow PHI cycles
      if (!PHISeen.insert(&UserMI).second)
        continue;
      LLVM_DEBUG(dbgs() << "    Peeking through uses of PHI\n");
      Worklist.insert_range(llvm::make_pointer_range(
          MRI->use_operands(UserMI.getOperand(0).getReg())));
      continue;
    }

    auto VLOp = getMinimumVLForUser(UserOp);
    if (!VLOp)
      return std::nullopt;

    // Use the largest VL among all the users. If we cannot determine this
    // statically, then we cannot optimize the VL.
    if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
      CommonVL = *VLOp;
      LLVM_DEBUG(dbgs() << "    User VL is: " << VLOp << "\n");
    } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
      LLVM_DEBUG(dbgs() << "    Abort because cannot determine a common VL\n");
      return std::nullopt;
    }

    if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
      LLVM_DEBUG(dbgs() << "    Abort due to lack of SEW operand\n");
      return std::nullopt;
    }

    std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
    std::optional<OperandInfo> ProducerInfo =
        getOperandInfo(MI.getOperand(0), MRI);
    if (!ConsumerInfo || !ProducerInfo) {
      LLVM_DEBUG(dbgs() << "    Abort due to unknown operand information.\n");
      LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
      return std::nullopt;
    }

    // If the operand is used as a scalar operand, then the EEW must be
    // compatible. Otherwise, the EMUL *and* EEW must be compatible.
    bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp);
    if ((IsVectorOpUsedAsScalarOp &&
         !OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) ||
        (!IsVectorOpUsedAsScalarOp &&
         !OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) {
      LLVM_DEBUG(
          dbgs()
          << "    Abort due to incompatible information for EMUL or EEW.\n");
      LLVM_DEBUG(dbgs() << "      ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << "      ProducerInfo is: " << ProducerInfo << "\n");
      return std::nullopt;
    }
  }

  return CommonVL;
}

/// Shrink MI's VL operand to the VL recorded for MI in DemandedVLs.
///
/// Returns true iff the VL operand was rewritten. MI is left untouched when:
///  - its VL is already the immediate 1,
///  - no demanded VL was recorded,
///  - the demanded VL is not known to be <= the current VL,
///  - the two are identical, or
///  - the demanded VL is a register whose definition does not dominate MI.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

  MachineOperand &VLOp = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));

  // A VL of 1 cannot usefully shrink further. This early-out is only an
  // optimization; correctness does not depend on it.
  const bool IsVLOne = VLOp.isImm() && VLOp.getImm() == 1;
  if (IsVLOne) {
    LLVM_DEBUG(dbgs() << "  Abort due to VL == 1, no point in reducing.\n");
    return false;
  }

  const auto CommonVL = DemandedVLs.lookup(&MI);
  if (!CommonVL)
    return false;
  const MachineOperand &NewVL = *CommonVL;

  assert((NewVL.isImm() || NewVL.getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  if (!RISCV::isVLKnownLE(NewVL, VLOp)) {
    LLVM_DEBUG(dbgs() << "    Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (NewVL.isIdenticalTo(VLOp)) {
    LLVM_DEBUG(
        dbgs() << "    Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  if (!NewVL.isImm()) {
    // A register VL may only be used where its definition dominates MI.
    const MachineInstr *VLDef = MRI->getVRegDef(NewVL.getReg());
    if (!MDT->dominates(VLDef, &MI))
      return false;
    LLVM_DEBUG(
        dbgs() << "  Reduce VL from " << VLOp << " to "
               << printReg(NewVL.getReg(), MRI->getTargetRegisterInfo())
               << " for " << MI << "\n");
    VLOp.ChangeToRegister(NewVL.getReg(), /*isDef=*/false);
    return true;
  }

  LLVM_DEBUG(dbgs() << "  Reduce VL from " << VLOp << " to "
                    << NewVL.getImm() << " for " << MI << "\n");
  VLOp.ChangeToImmediate(NewVL.getImm());
  return true;
}

/// Pass entry point.
///
/// Phase 1 visits blocks in post-order, and each block's instructions from
/// bottom to top. For every candidate it records in DemandedVLs the VL its
/// users demand (checkUsers).
/// Phase 2 visits every block reachable from entry and tries to shrink each
/// candidate's VL to that recorded value (tryReduceVL).
/// DemandedVLs is expected to be empty on entry and is cleared before the
/// final return. Returns true if any VL operand was changed.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  assert(DemandedVLs.empty());
  if (skipFunction(MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  if (!MF.getSubtarget<RISCVSubtarget>().hasVInstructions())
    return false;

  // Phase 1: record what VL each vector-defining candidate's users demand.
  for (MachineBasicBlock *MBB : post_order(&MF)) {
    assert(MDT->isReachableFromEntry(MBB));
    for (MachineInstr &MI : reverse(*MBB))
      if (isCandidate(MI))
        DemandedVLs.insert({&MI, checkUsers(MI)});
  }

  // Phase 2: shrink each candidate's VL down to what was demanded.
  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    // Dominance is degenerate in unreachable blocks, so skip them.
    if (!MDT->isReachableFromEntry(&MBB))
      continue;
    for (MachineInstr &MI : reverse(MBB))
      if (isCandidate(MI) && tryReduceVL(MI))
        MadeChange = true;
  }

  DemandedVLs.clear();
  return MadeChange;
}
