// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN:  --sparsification --sparse-tensor-codegen \
// RUN:  --canonicalize --cse | FileCheck %s

// Sparse tensor encoding used by every tensor type in this test (alias #CSR):
// dimension 0 has the "dense" level type, dimension 1 has the "compressed"
// level type, and dimOrdering is the identity map on (i,j).
#CSR = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>
}>

//
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
// CHECK-LABEL:   func.func private @_insert_dense_compressed_4_4_f64_0_0(
// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf64>,
// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:      %[[VAL_4:.*4]]: index,
// CHECK-SAME:      %[[VAL_5:.*5]]: index,
// CHECK-SAME:      %[[VAL_6:.*6]]: f64) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_7:.*]] = arith.constant false
// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index
// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
// CHECK:           %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  idx_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index
// CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index
// CHECK:           %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index
// CHECK:           %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK:             %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index
// CHECK:             scf.yield %[[VAL_18]] : i1
// CHECK:           } else {
// CHECK:             memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:             scf.yield %[[VAL_7]] : i1
// CHECK:           }
// CHECK:           %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK:             scf.yield %[[VAL_1]], %[[VAL_3]] : memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK:           } else {
// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index
// CHECK:             memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref<?xindex>, index
// CHECK:             %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64
// CHECK:             %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]]  idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier
// CHECK:             scf.yield %[[VAL_22]], %[[VAL_25]] : memref<?xindex>, !sparse_tensor.storage_specifier
// CHECK:           }
// CHECK:           %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1  val_mem_sz : !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index
// CHECK:           %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref<?xf64>, f64
// CHECK:           %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64
// CHECK:           %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1  val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier
// CHECK:           return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:         }

// CHECK-LABEL:   func.func @matmul(
// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf64>,
// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:      %[[VAL_4:.*4]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_5:.*5]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_6:.*6]]: memref<?xf64>,
// CHECK-SAME:      %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_8:.*]] = arith.constant 4 : index
// CHECK:           %[[VAL_9:.*]] = arith.constant 4 : i64
// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64
// CHECK:           %[[VAL_11:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_13:.*]] = arith.constant false
// CHECK:           %[[VAL_14:.*]] = arith.constant true
// CHECK:           %[[VAL_15:.*]] = memref.alloc() : memref<16xindex>
// CHECK:           %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref<?xindex>
// CHECK:           %[[VAL_17:.*]] = memref.alloc() : memref<16xindex>
// CHECK:           %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref<?xindex>
// CHECK:           %[[VAL_19:.*]] = memref.alloc() : memref<16xf64>
// CHECK:           %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref<?xf64>
// CHECK:           %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]]  dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]]  dim_sz at 1 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]]  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index
// CHECK:           %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref<?xindex>, index
// CHECK:           %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64
// CHECK:           %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]]  ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref<?xindex>, index, index
// CHECK:           %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64
// CHECK:           %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]]  ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_36:.*]] = memref.alloc() : memref<4xf64>
// CHECK:           %[[VAL_37:.*]] = memref.alloc() : memref<4xi1>
// CHECK:           %[[VAL_38:.*]] = memref.alloc() : memref<4xindex>
// CHECK:           %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref<?xindex>
// CHECK:           linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>)
// CHECK:           linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>)
// CHECK:           %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK:             %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index
// CHECK:             %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref<?xindex>
// CHECK:             %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) {
// CHECK:               %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref<?xindex>
// CHECK:               %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref<?xf64>
// CHECK:               %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref<?xindex>
// CHECK:               %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index
// CHECK:               %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref<?xindex>
// CHECK:               %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) {
// CHECK:                 %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref<?xindex>
// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
// CHECK:                 %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref<?xf64>
// CHECK:                 %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64
// CHECK:                 %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64
// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
// CHECK:                 %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1
// CHECK:                 %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) {
// CHECK:                   memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1>
// CHECK:                   memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex>
// CHECK:                   %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index
// CHECK:                   scf.yield %[[VAL_68]] : index
// CHECK:                 } else {
// CHECK:                   scf.yield %[[VAL_59]] : index
// CHECK:                 }
// CHECK:                 memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64>
// CHECK:                 scf.yield %[[VAL_69:.*]] : index
// CHECK:               } {"Emitted from" = "linalg.generic"}
// CHECK:               scf.yield %[[VAL_70:.*]] : index
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
// CHECK:             %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:               %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
// CHECK:               %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
// CHECK:               %[[VAL_80:.*]]:4 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:               memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
// CHECK:               memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1>
// CHECK:               scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:             }
// CHECK:             scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           memref.dealloc %[[VAL_36]] : memref<4xf64>
// CHECK:           memref.dealloc %[[VAL_37]] : memref<4xi1>
// CHECK:           memref.dealloc %[[VAL_38]] : memref<4xindex>
// CHECK:           %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3  ptr_mem_sz at 1 : !sparse_tensor.storage_specifier
// CHECK:           %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index
// CHECK:           %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK:           %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) {
// CHECK:             %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
// CHECK:             %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index
// CHECK:             %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index
// CHECK:             scf.if %[[VAL_90]] {
// CHECK:               memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref<?xindex>
// CHECK:             }
// CHECK:             scf.yield %[[VAL_91]] : index
// CHECK:           }
// CHECK:           return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @matmul(%A: tensor<4x8xf64, #CSR>, %B: tensor<8x4xf64, #CSR>)
    -> tensor<4x4xf64, #CSR> {
  // Allocate the 4x4 #CSR tensor that is passed as the matmul output operand.
  %init = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
  // Named matmul over the 4x8 and 8x4 #CSR inputs; the pipeline in the run
  // lines at the top of the file generalizes and sparsifies this op.
  %product = linalg.matmul
      ins(%A, %B : tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
      outs(%init : tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
  return %product : tensor<4x4xf64, #CSR>
}
