//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//

#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "dxil-data-scalarization"
static const int MaxVecSize = 4;

using namespace llvm;

// Recursively creates an array-like version of a given vector type.
static Type *equivalentArrayTypeFromVector(Type *T) {
  if (auto *VecTy = dyn_cast<VectorType>(T))
    return ArrayType::get(VecTy->getElementType(),
                          dyn_cast<FixedVectorType>(VecTy)->getNumElements());
  if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
    Type *NewElementType =
        equivalentArrayTypeFromVector(ArrayTy->getElementType());
    return ArrayType::get(NewElementType, ArrayTy->getNumElements());
  }
  // If it's not a vector or array, return the original type.
  return T;
}

class DXILDataScalarizationLegacy : public ModulePass {

public:
  bool runOnModule(Module &M) override;
  DXILDataScalarizationLegacy() : ModulePass(ID) {}

  static char ID; // Pass identification.
};

static bool findAndReplaceVectors(Module &M);

class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
  DataScalarizerVisitor() : GlobalMap() {}
  bool visit(Function &F);
  // InstVisitor methods.  They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitAllocaInst(AllocaInst &AI);
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI) { return false; }
  bool visitICmpInst(ICmpInst &ICI) { return false; }
  bool visitFCmpInst(FCmpInst &FCI) { return false; }
  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI) { return false; }
  bool visitBitCastInst(BitCastInst &BCI) { return false; }
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
  bool visitPHINode(PHINode &PHI) { return false; }
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI) { return false; }
  bool visitFreezeInst(FreezeInst &FI) { return false; }
  friend bool findAndReplaceVectors(llvm::Module &M);

private:
  typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;
  typedef SmallDenseMap<Value *, AllocaAndGEPs>
      VectorToArrayMap; // A map from a vector-typed Value to its corresponding
                        // AllocaInst and GEPs to each element of an array
  VectorToArrayMap VectorAllocaMap;
  AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
                                      const Twine &Name);
  bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
  bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);

  GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
};

bool DataScalarizerVisitor::visit(Function &F) {
  bool MadeChange = false;
  ReversePostOrderTraversal<Function *> RPOT(&F);
  for (BasicBlock *BB : make_early_inc_range(RPOT)) {
    for (Instruction &I : make_early_inc_range(*BB))
      MadeChange |= InstVisitor::visit(I);
  }
  VectorAllocaMap.clear();
  return MadeChange;
}

GlobalVariable *
DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
  if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {
    auto It = GlobalMap.find(OldGlobal);
    if (It != GlobalMap.end()) {
      return It->second; // Found, return the new global
    }
  }
  return nullptr; // Not found
}

// Helper function to check if a type is a vector or an array of vectors
static bool isVectorOrArrayOfVectors(Type *T) {
  if (isa<VectorType>(T))
    return true;
  if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
    return isa<VectorType>(ArrType->getElementType()) ||
           isVectorOrArrayOfVectors(ArrType->getElementType());
  return false;
}

bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
  Type *AllocatedType = AI.getAllocatedType();
  if (!isVectorOrArrayOfVectors(AllocatedType))
    return false;

  IRBuilder<> Builder(&AI);
  Type *NewType = equivalentArrayTypeFromVector(AllocatedType);
  AllocaInst *ArrAlloca =
      Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");
  ArrAlloca->setAlignment(AI.getAlign());
  AI.replaceAllUsesWith(ArrAlloca);
  AI.eraseFromParent();
  return true;
}

bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  Value *PtrOperand = LI.getPointerOperand();
  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    OldGEP->insertBefore(LI.getIterator());
    IRBuilder<> Builder(&LI);
    LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
    NewLoad->setAlignment(LI.getAlign());
    LI.replaceAllUsesWith(NewLoad);
    LI.eraseFromParent();
    visitGetElementPtrInst(*OldGEP);
    return true;
  }
  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
    LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
  return false;
}

bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {

  Value *PtrOperand = SI.getPointerOperand();
  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
    GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    OldGEP->insertBefore(SI.getIterator());
    IRBuilder<> Builder(&SI);
    StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
    NewStore->setAlignment(SI.getAlign());
    SI.replaceAllUsesWith(NewStore);
    SI.eraseFromParent();
    visitGetElementPtrInst(*OldGEP);
    return true;
  }
  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))
    SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);

  return false;
}

DataScalarizerVisitor::AllocaAndGEPs
DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
                                             const Twine &Name = "") {
  // If there is already an alloca for this vector, return it
  if (VectorAllocaMap.contains(Vec))
    return VectorAllocaMap[Vec];

  auto InsertPoint = Builder.GetInsertPoint();

  // Allocate the array to hold the vector elements
  Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
  Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
  AllocaInst *ArrAlloca =
      Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
  const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

  // Create loads and stores to populate the array immediately after the
  // original vector's defining instruction if available, else immediately after
  // the alloca
  if (auto *Instr = dyn_cast<Instruction>(Vec))
    Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());
  SmallVector<Value *, 4> GEPs(ArrNumElems);
  for (unsigned I = 0; I < ArrNumElems; ++I) {
    Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
    GEPs[I] = Builder.CreateInBoundsGEP(
        ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
        Name + ".index");
    Builder.CreateStore(EE, GEPs[I]);
  }

  VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
  Builder.SetInsertPoint(InsertPoint);
  return {ArrAlloca, GEPs};
}

/// Returns a pair of Value* with the first being a GEP into ArrAlloca using
/// indices {0, Index}, and the second Value* being a Load of the GEP
static std::pair<Value *, Value *>
dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
                     const Twine &Name = "") {
  Type *ArrTy = ArrAlloca->getAllocatedType();
  Value *GEP = Builder.CreateInBoundsGEP(
      ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
  Value *Load =
      Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
  return std::make_pair(GEP, Load);
}

bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
    InsertElementInst &IEI) {
  IRBuilder<> Builder(&IEI);

  Value *Vec = IEI.getOperand(0);
  Value *Val = IEI.getOperand(1);
  Value *Index = IEI.getOperand(2);

  AllocaAndGEPs ArrAllocaAndGEPs =
      createArrayFromVector(Builder, Vec, IEI.getName());
  AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
  Type *ArrTy = ArrAlloca->getAllocatedType();
  SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;

  auto GEPAndLoad =
      dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());
  Value *GEP = GEPAndLoad.first;
  Value *Load = GEPAndLoad.second;

  Builder.CreateStore(Val, GEP);
  Value *NewIEI = PoisonValue::get(Vec->getType());
  for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
    Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
                                     IEI.getName() + ".load");
    NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
                                         IEI.getName() + ".insert");
  }

  // Store back the original value so the Alloca can be reused for subsequent
  // insertelement instructions on the same vector
  Builder.CreateStore(Load, GEP);

  IEI.replaceAllUsesWith(NewIEI);
  IEI.eraseFromParent();
  return true;
}

bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  // If the index is a constant then we don't need to scalarize it
  Value *Index = IEI.getOperand(2);
  if (isa<ConstantInt>(Index))
    return false;
  return replaceDynamicInsertElementInst(IEI);
}

bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
    ExtractElementInst &EEI) {
  IRBuilder<> Builder(&EEI);

  AllocaAndGEPs ArrAllocaAndGEPs =
      createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
  AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;

  auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,
                                         EEI.getIndexOperand(), EEI.getName());
  Value *Load = GEPAndLoad.second;

  EEI.replaceAllUsesWith(Load);
  EEI.eraseFromParent();
  return true;
}

bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  // If the index is a constant then we don't need to scalarize it
  Value *Index = EEI.getIndexOperand();
  if (isa<ConstantInt>(Index))
    return false;
  return replaceDynamicExtractElementInst(EEI);
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  Value *PtrOperand = GEPI.getPointerOperand();
  Type *OrigGEPType = GEPI.getSourceElementType();
  Type *NewGEPType = OrigGEPType;
  bool NeedsTransform = false;

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
    NewGEPType = NewGlobal->getValueType();
    PtrOperand = NewGlobal;
    NeedsTransform = true;
  } else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
    Type *AllocatedType = Alloca->getAllocatedType();
    // Only transform if the allocated type is an array
    if (AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType)) {
      NewGEPType = AllocatedType;
      NeedsTransform = true;
    }
  }

  // Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will
  // convert these scalar geps into flattened array geps
  if (!isa<ArrayType>(OrigGEPType))
    NewGEPType = OrigGEPType;

  // Note: We bail if this isn't a gep touched via alloca or global
  // transformations
  if (!NeedsTransform)
    return false;

  IRBuilder<> Builder(&GEPI);
  SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());

  Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
                                    GEPI.getName(), GEPI.getNoWrapFlags());
  GEPI.replaceAllUsesWith(NewGEP);
  GEPI.eraseFromParent();
  return true;
}

static Constant *transformInitializer(Constant *Init, Type *OrigType,
                                      Type *NewType, LLVMContext &Ctx) {
  // Handle ConstantAggregateZero (zero-initialized constants)
  if (isa<ConstantAggregateZero>(Init)) {
    return ConstantAggregateZero::get(NewType);
  }

  // Handle UndefValue (undefined constants)
  if (isa<UndefValue>(Init)) {
    return UndefValue::get(NewType);
  }

  // Handle vector to array transformation
  if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
    // Convert vector initializer to array initializer
    SmallVector<Constant *, MaxVecSize> ArrayElements;
    if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
      for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
        ArrayElements.push_back(ConstVecInit->getOperand(I));
    } else if (ConstantDataVector *ConstDataVecInit =
                   llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {
      for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
        ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
    } else {
      assert(false && "Expected a ConstantVector or ConstantDataVector for "
                      "vector initializer!");
    }

    return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
  }

  // Handle array of vectors transformation
  if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
    auto *ArrayInit = dyn_cast<ConstantArray>(Init);
    assert(ArrayInit && "Expected a ConstantArray for array initializer!");

    SmallVector<Constant *, MaxVecSize> NewArrayElements;
    for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {
      // Recursively transform array elements
      Constant *NewElemInit = transformInitializer(
          ArrayInit->getOperand(I), ArrayTy->getElementType(),
          cast<ArrayType>(NewType)->getElementType(), Ctx);
      NewArrayElements.push_back(NewElemInit);
    }

    return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);
  }

  // If not a vector or array, return the original initializer
  return Init;
}

static bool findAndReplaceVectors(Module &M) {
  bool MadeChange = false;
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(Ctx);
  DataScalarizerVisitor Impl;
  for (GlobalVariable &G : M.globals()) {
    Type *OrigType = G.getValueType();

    Type *NewType = equivalentArrayTypeFromVector(OrigType);
    if (OrigType != NewType) {
      // Create a new global variable with the updated type
      // Note: Initializer is set via transformInitializer
      GlobalVariable *NewGlobal = new GlobalVariable(
          M, NewType, G.isConstant(), G.getLinkage(),
          /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
          G.getThreadLocalMode(), G.getAddressSpace(),
          G.isExternallyInitialized());

      // Copy relevant attributes
      NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
      if (G.getAlignment() > 0) {
        NewGlobal->setAlignment(G.getAlign());
      }

      if (G.hasInitializer()) {
        Constant *Init = G.getInitializer();
        Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);
        NewGlobal->setInitializer(NewInit);
      }

      // Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes
      // type equality. Instead we will use the visitor pattern.
      Impl.GlobalMap[&G] = NewGlobal;
    }
  }

  for (auto &F : make_early_inc_range(M.functions())) {
    if (F.isDeclaration())
      continue;
    MadeChange |= Impl.visit(F);
  }

  // Remove the old globals after the iteration
  for (auto &[Old, New] : Impl.GlobalMap) {
    Old->eraseFromParent();
    MadeChange = true;
  }
  return MadeChange;
}

PreservedAnalyses DXILDataScalarization::run(Module &M,
                                             ModuleAnalysisManager &) {
  bool MadeChanges = findAndReplaceVectors(M);
  if (!MadeChanges)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  return PA;
}

bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
  return findAndReplaceVectors(M);
}

char DXILDataScalarizationLegacy::ID = 0;

INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
                      "DXIL Data Scalarization", false, false)
INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
                    "DXIL Data Scalarization", false, false)

ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
  return new DXILDataScalarizationLegacy();
}
