// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: odml-to-stablehlo-opt -tf-legalize-hlo -verify-diagnostics %s | FileCheck %s

// A rank-1 operand mapped onto the last dimension (3) of a rank-4 operand
// lowers to a single implicitly broadcasting tf.AddV2.
// CHECK-LABEL:   func @biasAdd_NHWC(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x32x10x32xi32>,
// CHECK-SAME:      %[[BIAS:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK:           %[[SUM:.*]] = "tf.AddV2"(%[[INPUT]], %[[BIAS]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
// CHECK:           return %[[SUM]] : tensor<1x32x10x32xi32>
// CHECK:         }
func.func @biasAdd_NHWC(%input: tensor<1x32x10x32xi32>, %bias: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  %sum = "chlo.broadcast_add"(%input, %bias) {broadcast_dimensions = array<i64: 3>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  func.return %sum : tensor<1x32x10x32xi32>
}

// A rank-1 operand mapped onto dimension 3 of a rank-4 operand lowers to a
// single implicitly broadcasting tf.AddV2.
// NOTE(review): shapes and broadcast_dimensions here are identical to
// @biasAdd_NHWC; an NCHW-style case would usually map the rank-1 operand
// onto dimension 1 — confirm whether this duplicate is intentional.
// CHECK-LABEL:   func @biasAdd_NCHW(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x32x10x32xi32>,
// CHECK-SAME:      %[[BIAS:.*]]: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK:           %[[SUM:.*]] = "tf.AddV2"(%[[INPUT]], %[[BIAS]]) : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
// CHECK:           return %[[SUM]] : tensor<1x32x10x32xi32>
// CHECK:         }
func.func @biasAdd_NCHW(%input: tensor<1x32x10x32xi32>, %bias: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  %sum = "chlo.broadcast_add"(%input, %bias) {broadcast_dimensions = array<i64: 3>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  func.return %sum : tensor<1x32x10x32xi32>
}

// Dynamic-shape variant: a rank-1 operand mapped onto the last dimension of
// a fully dynamic rank-4 operand still lowers to tf.AddV2.
// CHECK-LABEL:   func @biasAdd_dynamic(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x?x?x?xi32>,
// CHECK-SAME:      %[[BIAS:.*]]: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK:           %[[SUM:.*]] = "tf.AddV2"(%[[INPUT]], %[[BIAS]]) : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
// CHECK:           return %[[SUM]] : tensor<?x?x?x?xi32>
// CHECK:         }
func.func @biasAdd_dynamic(%input: tensor<?x?x?x?xi32>, %bias: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
  %sum = "chlo.broadcast_add"(%input, %bias) {broadcast_dimensions = array<i64: 3>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
  func.return %sum : tensor<?x?x?x?xi32>
}

// Two chained same-shape mhlo.add ops each lower one-to-one to tf.AddV2.
// CHECK-LABEL:   func @add(
// CHECK-SAME:      %[[X:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[DOUBLED:.*]] = "tf.AddV2"(%[[X]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           %[[TRIPLED:.*]] = "tf.AddV2"(%[[DOUBLED]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[TRIPLED]] : tensor<2xi32>
// CHECK:         }
func.func @add(%x: tensor<2xi32>) -> tensor<2xi32> {
  %doubled = "mhlo.add"(%x, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %tripled = "mhlo.add"(%doubled, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %tripled : tensor<2xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.AddV2 ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_add(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.AddV2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.AddV2"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_add(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.add %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.add %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// chlo.broadcast_add of a rank-1 operand mapped onto the last dimension of a
// rank-2 operand lowers directly to tf.AddV2.
// CHECK-LABEL:   func @broadcast_add_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK:           %[[SUM:.*]] = "tf.AddV2"(%[[LHS]], %[[RHS]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK:           return %[[SUM]] : tensor<1x2xi32>
// CHECK:         }
func.func @broadcast_add_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %sum = "chlo.broadcast_add"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  func.return %sum : tensor<1x2xi32>
}

// A rank-3 operand mapped onto the trailing dimensions [1, 2, 3] of a rank-4
// operand lowers to tf.AddV2. Compare @unsupported_broadcast_add, which uses
// leading dimensions instead.
// CHECK-LABEL:   func @broadcast_multi_dim_add(
// CHECK-SAME:      %[[SMALL:.*]]: tensor<4x1x1xi32>,
// CHECK-SAME:      %[[BIG:.*]]: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
// CHECK:           %[[SUM:.*]] = "tf.AddV2"(%[[SMALL]], %[[BIG]]) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
// CHECK:           return %[[SUM]] : tensor<4x4x4x4xi32>
// CHECK:         }
func.func @broadcast_multi_dim_add(%small: tensor<4x1x1xi32>, %big: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
  %sum = "chlo.broadcast_add"(%small, %big) {broadcast_dimensions = array<i64: 1, 2, 3>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
  func.return %sum : tensor<4x4x4x4xi32>
}

// Negative case: the rank-3 operand is mapped onto the leading dimensions
// [0, 1, 2] rather than the trailing ones, and the op is expected to be left
// as chlo.broadcast_add. Presumably this is because such a mapping is not
// expressible as an implicit TF broadcast — confirm against the pattern.
// CHECK-LABEL:   func @unsupported_broadcast_add
// CHECK: chlo.broadcast_add
func.func @unsupported_broadcast_add(%small: tensor<4x1x1xi32>, %big: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
  %sum = "chlo.broadcast_add"(%small, %big) {broadcast_dimensions = array<i64: 0, 1, 2>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
  func.return %sum : tensor<4x4x4x4xi32>
}

// Same-shape integer mhlo.divide lowers to tf.Div.
// CHECK-LABEL:   func @div(
// CHECK-SAME:      %[[X:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[QUOT:.*]] = "tf.Div"(%[[X]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[QUOT]] : tensor<2xi32>
// CHECK:         }
func.func @div(%x: tensor<2xi32>) -> tensor<2xi32> {
  %quot = "mhlo.divide"(%x, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %quot : tensor<2xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.Div ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_div(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Div"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_div(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.divide %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.divide %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// chlo.broadcast_divide of a rank-1 operand mapped onto the last dimension of
// a rank-2 operand lowers directly to tf.Div.
// CHECK-LABEL:   func @broadcast_div_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK:           %[[QUOT:.*]] = "tf.Div"(%[[LHS]], %[[RHS]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK:           return %[[QUOT]] : tensor<1x2xi32>
// CHECK:         }
func.func @broadcast_div_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %quot = "chlo.broadcast_divide"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  func.return %quot : tensor<1x2xi32>
}

// Same-shape mhlo.shift_left lowers to tf.LeftShift.
// CHECK-LABEL:   func @shift_left(
// CHECK-SAME:      %[[VALUE:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[AMOUNT:.*]]: tensor<4xi32>) -> tensor<4xi32> {
// CHECK:           %[[SHIFTED:.*]] = "tf.LeftShift"(%[[VALUE]], %[[AMOUNT]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           return %[[SHIFTED]] : tensor<4xi32>
// CHECK:         }
func.func @shift_left(%value: tensor<4xi32>, %amount: tensor<4xi32>) -> tensor<4xi32> {
  %shifted = "mhlo.shift_left"(%value, %amount) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  func.return %shifted : tensor<4xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1-element operand is absorbed by
// the TF op: both tf.LeftShift ops consume the original tensor<1xi32>
// argument directly, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_shift_left(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.LeftShift"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           %[[VAL_3:.*]] = "tf.LeftShift"(%[[VAL_1]], %[[VAL_0]]) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<4xi32>, tensor<4xi32>
// CHECK:         }
func.func @broadcast_shift_left(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<4xi32>
  %1 = mhlo.shift_left %0, %arg1 : tensor<4xi32>
  %2 = mhlo.shift_left %arg1, %0 : tensor<4xi32>
  func.return %1, %2 : tensor<4xi32>, tensor<4xi32>
}

// Dynamic-shape chlo.broadcast_divide of a rank-1 operand mapped onto the
// last dimension of a rank-2 operand lowers to tf.Div.
// CHECK-LABEL:   func @div_dynamic(
// CHECK-SAME:      %[[LHS:.*]]: tensor<?xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
// CHECK:           %[[QUOT:.*]] = "tf.Div"(%[[LHS]], %[[RHS]]) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK:           return %[[QUOT]] : tensor<?x?xi32>
// CHECK:         }
func.func @div_dynamic(%lhs: tensor<?xi32>, %rhs: tensor<?x?xi32>) -> tensor<?x?xi32> {
  %quot = "chlo.broadcast_divide"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  func.return %quot : tensor<?x?xi32>
}

// Same-shape mhlo.maximum lowers to tf.Maximum.
// CHECK-LABEL:   func @maximum(
// CHECK-SAME:      %[[LHS:.*]]: tensor<4xf32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<4xf32>) -> tensor<4xf32> {
// CHECK:           %[[MAX:.*]] = "tf.Maximum"(%[[LHS]], %[[RHS]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK:           return %[[MAX]] : tensor<4xf32>
// CHECK:         }
func.func @maximum(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> tensor<4xf32> {
  %max = "mhlo.maximum"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  func.return %max : tensor<4xf32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.Maximum ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_maximum(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_maximum(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.maximum %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.maximum %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// Same-shape mhlo.minimum lowers to tf.Minimum.
// CHECK-LABEL:   func @minimum(
// CHECK-SAME:      %[[LHS:.*]]: tensor<4xf32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<4xf32>) -> tensor<4xf32> {
// CHECK:           %[[MIN:.*]] = "tf.Minimum"(%[[LHS]], %[[RHS]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK:           return %[[MIN]] : tensor<4xf32>
// CHECK:         }
func.func @minimum(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> tensor<4xf32> {
  %min = "mhlo.minimum"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  func.return %min : tensor<4xf32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.Minimum ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_minimum(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_minimum(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.minimum %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.minimum %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// Same-shape mhlo.multiply lowers to tf.Mul.
// CHECK-LABEL:   func @mul(
// CHECK-SAME:      %[[X:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[PROD:.*]] = "tf.Mul"(%[[X]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[PROD]] : tensor<2xi32>
// CHECK:         }
func.func @mul(%x: tensor<2xi32>) -> tensor<2xi32> {
  %prod = "mhlo.multiply"(%x, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %prod : tensor<2xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.Mul ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_mul(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Mul"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_mul(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.multiply %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.multiply %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// chlo.broadcast_multiply of a rank-1 operand mapped onto the last dimension
// of a rank-2 operand lowers directly to tf.Mul.
// CHECK-LABEL:   func @broadcast_mul_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK:           %[[PROD:.*]] = "tf.Mul"(%[[LHS]], %[[RHS]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK:           return %[[PROD]] : tensor<1x2xi32>
// CHECK:         }
func.func @broadcast_mul_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %prod = "chlo.broadcast_multiply"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  func.return %prod : tensor<1x2xi32>
}

// Integer mhlo.divide lowers to tf.Div. The input is identical to @div.
// CHECK-LABEL:   func @real_div(
// CHECK-SAME:      %[[X:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[QUOT:.*]] = "tf.Div"(%[[X]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[QUOT]] : tensor<2xi32>
// CHECK:         }
func.func @real_div(%x: tensor<2xi32>) -> tensor<2xi32> {
  %quot = "mhlo.divide"(%x, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %quot : tensor<2xi32>
}

// chlo.broadcast_divide on integers lowers to tf.Div. The input is identical
// to @broadcast_div_chlo.
// CHECK-LABEL:   func @broadcast_real_div(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK:           %[[QUOT:.*]] = "tf.Div"(%[[LHS]], %[[RHS]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK:           return %[[QUOT]] : tensor<1x2xi32>
// CHECK:         }
func.func @broadcast_real_div(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %quot = "chlo.broadcast_divide"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  func.return %quot : tensor<1x2xi32>
}

// Same-shape mhlo.subtract lowers to tf.Sub.
// CHECK-LABEL:   func @sub(
// CHECK-SAME:      %[[X:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[DIFF:.*]] = "tf.Sub"(%[[X]], %[[X]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[DIFF]] : tensor<2xi32>
// CHECK:         }
func.func @sub(%x: tensor<2xi32>) -> tensor<2xi32> {
  %diff = "mhlo.subtract"(%x, %x) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %diff : tensor<2xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1x1 operand is absorbed by the
// TF op: both tf.Sub ops consume the original 1x1 argument directly and
// rely on implicit broadcasting, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_sub(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1x1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Sub"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xf32>, tensor<1x1000xf32>) -> tensor<1x1000xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Sub"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1000xf32>, tensor<1x1xf32>) -> tensor<1x1000xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<1x1000xf32>, tensor<1x1000xf32>
// CHECK:         }
func.func @broadcast_sub(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1000xf32>) -> (tensor<1x1000xf32>, tensor<1x1000xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<1x1000xf32>
  %1 = mhlo.subtract %0, %arg1 : tensor<1x1000xf32>
  %2 = mhlo.subtract %arg1, %0 : tensor<1x1000xf32>
  func.return %1, %2 : tensor<1x1000xf32>, tensor<1x1000xf32>
}

// chlo.broadcast_subtract of a rank-1 operand mapped onto the last dimension
// of a rank-2 operand lowers directly to tf.Sub.
// CHECK-LABEL:   func @broadcast_sub_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK:           %[[DIFF:.*]] = "tf.Sub"(%[[LHS]], %[[RHS]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK:           return %[[DIFF]] : tensor<1x2xi32>
// CHECK:         }
func.func @broadcast_sub_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi32> {
  %diff = "chlo.broadcast_subtract"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  func.return %diff : tensor<1x2xi32>
}

// The explicit mhlo.broadcast_in_dim of the 1-element operand is absorbed by
// the TF op: both tf.Atan2 ops consume the original tensor<1xf32> argument
// directly, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_atan2(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1xf32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Atan2"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.Atan2"(%[[VAL_1]], %[[VAL_0]]) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<4xf32>, tensor<4xf32>
// CHECK:         }
func.func @broadcast_atan2(%arg0: tensor<1xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xf32>) -> tensor<4xf32>
  %1 = mhlo.atan2 %0, %arg1 : tensor<4xf32>
  %2 = mhlo.atan2 %arg1, %0 : tensor<4xf32>
  func.return %1, %2 : tensor<4xf32>, tensor<4xf32>
}

// Same-shape mhlo.shift_right_arithmetic lowers to tf.RightShift.
// CHECK-LABEL:   func @shift_right(
// CHECK-SAME:      %[[VALUE:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[AMOUNT:.*]]: tensor<4xi32>) -> tensor<4xi32> {
// CHECK:           %[[SHIFTED:.*]] = "tf.RightShift"(%[[VALUE]], %[[AMOUNT]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           return %[[SHIFTED]] : tensor<4xi32>
// CHECK:         }
func.func @shift_right(%value: tensor<4xi32>, %amount: tensor<4xi32>) -> tensor<4xi32> {
  %shifted = "mhlo.shift_right_arithmetic"(%value, %amount) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  func.return %shifted : tensor<4xi32>
}

// chlo.broadcast_shift_right_arithmetic of a rank-1 operand mapped onto the
// last dimension of a rank-2 operand lowers to tf.RightShift.
// CHECK-LABEL:   func @broadcast_shift_right(
// CHECK-SAME:      %[[VALUE:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[AMOUNT:.*]]: tensor<2x4xi32>) -> tensor<2x4xi32> {
// CHECK:           %[[SHIFTED:.*]] = "tf.RightShift"(%[[VALUE]], %[[AMOUNT]]) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// CHECK:           return %[[SHIFTED]] : tensor<2x4xi32>
// CHECK:         }
func.func @broadcast_shift_right(%value: tensor<4xi32>, %amount: tensor<2x4xi32>) -> tensor<2x4xi32> {
  %shifted = "chlo.broadcast_shift_right_arithmetic"(%value, %amount) {broadcast_dimensions = array<i64: 1>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
  func.return %shifted : tensor<2x4xi32>
}

// mhlo.and on i1 tensors lowers to tf.LogicalAnd.
// CHECK-LABEL:   func @and(
// CHECK-SAME:      %[[P:.*]]: tensor<2xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<2xi1>) -> tensor<2xi1> {
// CHECK:           %[[CONJ:.*]] = "tf.LogicalAnd"(%[[P]], %[[Q]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
// CHECK:           return %[[CONJ]] : tensor<2xi1>
// CHECK:         }
func.func @and(%p: tensor<2xi1>, %q: tensor<2xi1>) -> tensor<2xi1> {
  %conj = "mhlo.and"(%p, %q) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
  func.return %conj : tensor<2xi1>
}

// chlo.broadcast_and on i1 tensors, with a rank-1 operand mapped onto the
// last dimension, lowers to tf.LogicalAnd.
// CHECK-LABEL:   func @and_broadcast(
// CHECK-SAME:      %[[P:.*]]: tensor<1xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> {
// CHECK:           %[[CONJ:.*]] = "tf.LogicalAnd"(%[[P]], %[[Q]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
// CHECK:           return %[[CONJ]] : tensor<1x2xi1>
// CHECK:         }
func.func @and_broadcast(%p: tensor<1xi1>, %q: tensor<1x2xi1>) -> tensor<1x2xi1> {
  %conj = "chlo.broadcast_and"(%p, %q) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
  func.return %conj : tensor<1x2xi1>
}

// chlo.broadcast_and with a dynamic operand and no broadcast_dimensions
// attribute lowers to tf.LogicalAnd.
// CHECK-LABEL:   func @and_dynamic(
// CHECK-SAME:      %[[P:.*]]: tensor<?xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<1xi1>) -> tensor<?xi1> {
// CHECK:           %[[CONJ:.*]] = "tf.LogicalAnd"(%[[P]], %[[Q]]) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
// CHECK:           return %[[CONJ]] : tensor<?xi1>
// CHECK:         }
func.func @and_dynamic(%p: tensor<?xi1>, %q: tensor<1xi1>) -> tensor<?xi1> {
  %conj = "chlo.broadcast_and"(%p, %q) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
  func.return %conj : tensor<?xi1>
}

// mhlo.or on i1 tensors lowers to tf.LogicalOr.
// CHECK-LABEL:   func @or(
// CHECK-SAME:      %[[P:.*]]: tensor<2xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<2xi1>) -> tensor<2xi1> {
// CHECK:           %[[DISJ:.*]] = "tf.LogicalOr"(%[[P]], %[[Q]]) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
// CHECK:           return %[[DISJ]] : tensor<2xi1>
// CHECK:         }
func.func @or(%p: tensor<2xi1>, %q: tensor<2xi1>) -> tensor<2xi1> {
  %disj = "mhlo.or"(%p, %q) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
  func.return %disj : tensor<2xi1>
}

// chlo.broadcast_or on i1 tensors, with a rank-1 operand mapped onto the
// last dimension, lowers to tf.LogicalOr.
// CHECK-LABEL:   func @or_broadcast(
// CHECK-SAME:      %[[P:.*]]: tensor<1xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<1x2xi1>) -> tensor<1x2xi1> {
// CHECK:           %[[DISJ:.*]] = "tf.LogicalOr"(%[[P]], %[[Q]]) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
// CHECK:           return %[[DISJ]] : tensor<1x2xi1>
// CHECK:         }
func.func @or_broadcast(%p: tensor<1xi1>, %q: tensor<1x2xi1>) -> tensor<1x2xi1> {
  %disj = "chlo.broadcast_or"(%p, %q) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
  func.return %disj : tensor<1x2xi1>
}

// chlo.broadcast_or with a dynamic operand and no broadcast_dimensions
// attribute lowers to tf.LogicalOr.
// CHECK-LABEL:   func @or_dynamic(
// CHECK-SAME:      %[[P:.*]]: tensor<?xi1>,
// CHECK-SAME:      %[[Q:.*]]: tensor<1xi1>) -> tensor<?xi1> {
// CHECK:           %[[DISJ:.*]] = "tf.LogicalOr"(%[[P]], %[[Q]]) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
// CHECK:           return %[[DISJ]] : tensor<?xi1>
// CHECK:         }
func.func @or_dynamic(%p: tensor<?xi1>, %q: tensor<1xi1>) -> tensor<?xi1> {
  %disj = "chlo.broadcast_or"(%p, %q) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
  func.return %disj : tensor<?xi1>
}

// mhlo.or on non-i1 integer tensors lowers to tf.BitwiseOr.
// CHECK-LABEL:   func @bitwise_or(
// CHECK-SAME:      %[[LHS:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<4xi32>) -> tensor<4xi32> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseOr"(%[[LHS]], %[[RHS]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           return %[[BITS]] : tensor<4xi32>
// CHECK:         }
func.func @bitwise_or(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>) -> tensor<4xi32> {
  %bits = "mhlo.or"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  func.return %bits : tensor<4xi32>
}

// An explicit mhlo.broadcast_in_dim feeding mhlo.or on i8 tensors is absorbed
// into an implicitly broadcasting tf.BitwiseOr on the original operand.
// CHECK-LABEL:   func @bitwise_or_broadcast(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi8>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseOr"(%[[LHS]], %[[RHS]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
// CHECK:           return %[[BITS]] : tensor<1x4xi8>
// CHECK:         }
func.func @bitwise_or_broadcast(%lhs: tensor<1xi8>, %rhs: tensor<1x4xi8>) -> tensor<1x4xi8> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8>
  %bits = "mhlo.or"(%widened, %rhs) : (tensor<1x4xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
  func.return %bits : tensor<1x4xi8>
}

// chlo.broadcast_or on i8 tensors, with a rank-1 operand mapped onto the
// last dimension, lowers to tf.BitwiseOr.
// CHECK-LABEL:   func @bitwise_or_broadcast_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi8>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseOr"(%[[LHS]], %[[RHS]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
// CHECK:           return %[[BITS]] : tensor<1x4xi8>
// CHECK:         }
func.func @bitwise_or_broadcast_chlo(%lhs: tensor<1xi8>, %rhs: tensor<1x4xi8>) -> tensor<1x4xi8> {
  %bits = "chlo.broadcast_or"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
  func.return %bits : tensor<1x4xi8>
}

// chlo.broadcast_or on a dynamic i32 operand, with no broadcast_dimensions
// attribute, lowers to tf.BitwiseOr.
// CHECK-LABEL:   func @bitwise_or_dynamic(
// CHECK-SAME:      %[[LHS:.*]]: tensor<?xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1xi32>) -> tensor<?xi32> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseOr"(%[[LHS]], %[[RHS]]) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
// CHECK:           return %[[BITS]] : tensor<?xi32>
// CHECK:         }
func.func @bitwise_or_dynamic(%lhs: tensor<?xi32>, %rhs: tensor<1xi32>) -> tensor<?xi32> {
  %bits = "chlo.broadcast_or"(%lhs, %rhs) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
  func.return %bits : tensor<?xi32>
}

// Same-shape mhlo.xor on i32 tensors lowers to tf.BitwiseXor.
// CHECK-LABEL:   func @bitwise_xor(
// CHECK-SAME:      %[[LHS:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<4xi32>) -> tensor<4xi32> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseXor"(%[[LHS]], %[[RHS]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           return %[[BITS]] : tensor<4xi32>
// CHECK:         }
func.func @bitwise_xor(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>) -> tensor<4xi32> {
  %bits = "mhlo.xor"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  func.return %bits : tensor<4xi32>
}

// An explicit mhlo.broadcast_in_dim feeding mhlo.xor on i8 tensors is
// absorbed into an implicitly broadcasting tf.BitwiseXor on the original
// operand.
// CHECK-LABEL:   func @bitwise_xor_broadcast(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1xi8>,
// CHECK-SAME:                                %[[VAL_1:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK:           %[[VAL_2:.*]] = "tf.BitwiseXor"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
// CHECK:           return %[[VAL_2]] : tensor<1x4xi8>
// CHECK:         }
func.func @bitwise_xor_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8>
  %1 = mhlo.xor %0, %arg1 : tensor<1x4xi8>
  func.return %1 : tensor<1x4xi8>
}

// mhlo.and on non-i1 integer tensors lowers to tf.BitwiseAnd.
// CHECK-LABEL:   func @bitwise_and(
// CHECK-SAME:      %[[LHS:.*]]: tensor<4xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<4xi32>) -> tensor<4xi32> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseAnd"(%[[LHS]], %[[RHS]]) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           return %[[BITS]] : tensor<4xi32>
// CHECK:         }
func.func @bitwise_and(%lhs: tensor<4xi32>, %rhs: tensor<4xi32>) -> tensor<4xi32> {
  %bits = "mhlo.and"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  func.return %bits : tensor<4xi32>
}

// An explicit mhlo.broadcast_in_dim feeding mhlo.and on i8 tensors is
// absorbed into an implicitly broadcasting tf.BitwiseAnd on the original
// operand.
// CHECK-LABEL:   func @bitwise_and_broadcast(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi8>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseAnd"(%[[LHS]], %[[RHS]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
// CHECK:           return %[[BITS]] : tensor<1x4xi8>
// CHECK:         }
func.func @bitwise_and_broadcast(%lhs: tensor<1xi8>, %rhs: tensor<1x4xi8>) -> tensor<1x4xi8> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>}> : (tensor<1xi8>) -> tensor<1x4xi8>
  %bits = "mhlo.and"(%widened, %rhs) : (tensor<1x4xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
  func.return %bits : tensor<1x4xi8>
}

// chlo.broadcast_and on i8 tensors, with a rank-1 operand mapped onto the
// last dimension, lowers to tf.BitwiseAnd.
// CHECK-LABEL:   func @bitwise_and_broadcast_chlo(
// CHECK-SAME:      %[[LHS:.*]]: tensor<1xi8>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1x4xi8>) -> tensor<1x4xi8> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseAnd"(%[[LHS]], %[[RHS]]) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
// CHECK:           return %[[BITS]] : tensor<1x4xi8>
// CHECK:         }
func.func @bitwise_and_broadcast_chlo(%lhs: tensor<1xi8>, %rhs: tensor<1x4xi8>) -> tensor<1x4xi8> {
  %bits = "chlo.broadcast_and"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
  func.return %bits : tensor<1x4xi8>
}

// chlo.broadcast_and on a dynamic i32 operand, with no broadcast_dimensions
// attribute, lowers to tf.BitwiseAnd.
// CHECK-LABEL:   func @bitwise_and_dynamic(
// CHECK-SAME:      %[[LHS:.*]]: tensor<?xi32>,
// CHECK-SAME:      %[[RHS:.*]]: tensor<1xi32>) -> tensor<?xi32> {
// CHECK:           %[[BITS:.*]] = "tf.BitwiseAnd"(%[[LHS]], %[[RHS]]) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
// CHECK:           return %[[BITS]] : tensor<?xi32>
// CHECK:         }
func.func @bitwise_and_dynamic(%lhs: tensor<?xi32>, %rhs: tensor<1xi32>) -> tensor<?xi32> {
  %bits = "chlo.broadcast_and"(%lhs, %rhs) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
  func.return %bits : tensor<?xi32>
}

// Same-shape mhlo.power lowers to tf.Pow.
// CHECK-LABEL:   func @pow(
// CHECK-SAME:      %[[X:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[POWER:.*]] = "tf.Pow"(%[[X]], %[[X]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[POWER]] : tensor<2xf32>
// CHECK:         }
func.func @pow(%x: tensor<2xf32>) -> tensor<2xf32> {
  %power = "mhlo.power"(%x, %x) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  func.return %power : tensor<2xf32>
}

// The explicit mhlo.broadcast_in_dim of the 1-element operand is absorbed by
// the TF op: both tf.Pow ops consume the original tensor<1xi32> argument
// directly, with the broadcast operand on either side.
// CHECK-LABEL:   func @broadcast_pow(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
// CHECK:           %[[VAL_2:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK:           %[[VAL_3:.*]] = "tf.Pow"(%[[VAL_1]], %[[VAL_0]]) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           return %[[VAL_2]], %[[VAL_3]] : tensor<4xi32>, tensor<4xi32>
// CHECK:         }
func.func @broadcast_pow(%arg0: tensor<1xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<4xi32>
  %1 = mhlo.power %0, %arg1 : tensor<4xi32>
  %2 = mhlo.power %arg1, %0 : tensor<4xi32>
  func.return %1, %2 : tensor<4xi32>, tensor<4xi32>
}

// mhlo.power on a dynamically sized tensor still maps onto tf.Pow.
// CHECK-LABEL:   func @pow_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Pow"(%[[VAL_0]], %[[VAL_0]]) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @pow_dynamic(%base: tensor<?xf32>) -> tensor<?xf32> {
  %raised = mhlo.power %base, %base : tensor<?xf32>
  func.return %raised : tensor<?xf32>
}

// Integer floor-division expansion with a tensor<3xi32> divisor broadcast
// along dimension 1 of a tensor<2x3xi32> dividend.
//
// The expected output is an op-by-op translation:
//   compare -> tf.Less / tf.Equal, divide -> tf.Div, abs -> tf.Abs,
//   subtract -> tf.Sub, add -> tf.AddV2, negate -> tf.Neg, select -> tf.Select.
// Broadcasting chlo ops keep their unbroadcast operands, and the
// broadcast_dimensions attribute does not appear on the resulting TF ops.
// CHECK-LABEL:   func @floordiv_broadcast_i32(
// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<2x3xi32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<3xi32>) -> tensor<2x3xi32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK:           %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) <{incompatible_shape_error = true}> : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
// CHECK:           %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Const"() <{value = dense<1> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK:           %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK:           %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK:           %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           return %[[VAL_16]] : tensor<2x3xi32>
// CHECK:         }
func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // Dividend sign: arg0 < 0.
  %0 = mhlo.constant dense<0> : tensor<2x3xi32>
  %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = #chlo<comparison_direction LT>} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
  // Divisor sign: arg1 < 0.
  %2 = mhlo.constant dense<0> : tensor<3xi32>
  %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo<comparison_direction LT>} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  // True where both operands have the same sign; there the plain quotient %5 is used.
  %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
  %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array<i64: 1>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  // Opposite-sign branch: -(|arg0| + (|arg1| - 1)) / |arg1|.
  %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %8 = mhlo.constant dense<1> : tensor<3xi32>
  %9 = mhlo.subtract %7, %8 : tensor<3xi32>
  %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array<i64: 1>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = array<i64: 1>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  // Per element, choose the same-sign quotient or the adjusted quotient.
  %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  func.return %14 : tensor<2x3xi32>
}

// Mirror of @floordiv_broadcast_i32: here the tensor<3xi32> operand is the
// dividend and the tensor<2x3xi32> operand is the divisor.
//
// The input mixes a non-broadcasting mhlo.compare (for the dividend sign) with
// chlo.broadcast_compare (for the divisor sign). Both are expected to become
// tf.Less. The final divide is a plain mhlo.divide on tensor<2x3xi32>, which
// becomes tf.Div.
// CHECK-LABEL:   func @floordiv_reverse_broadcast_i32(
// CHECK-SAME:                                         %[[VAL_0:.*]]: tensor<3xi32>,
// CHECK-SAME:                                         %[[VAL_1:.*]]: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK:           %[[VAL_3:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Less"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
// CHECK:           %[[VAL_6:.*]] = "tf.Equal"(%[[VAL_3]], %[[VAL_5]]) <{incompatible_shape_error = true}> : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
// CHECK:           %[[VAL_7:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_8:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK:           %[[VAL_9:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK:           %[[VAL_11:.*]] = "tf.Sub"(%[[VAL_9]], %[[VAL_10]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_12:.*]] = "tf.AddV2"(%[[VAL_8]], %[[VAL_11]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_13:.*]] = "tf.Neg"(%[[VAL_12]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_14:.*]] = "tf.Abs"(%[[VAL_1]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_15:.*]] = "tf.Div"(%[[VAL_13]], %[[VAL_14]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           %[[VAL_16:.*]] = "tf.Select"(%[[VAL_6]], %[[VAL_7]], %[[VAL_15]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK:           return %[[VAL_16]] : tensor<2x3xi32>
// CHECK:         }
func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // Dividend sign: arg0 < 0 (non-broadcasting compare).
  %0 = mhlo.constant dense<0> : tensor<3xi32>
  %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  // Divisor sign: arg1 < 0.
  %2 = mhlo.constant dense<0> : tensor<2x3xi32>
  %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = #chlo<comparison_direction LT>} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
  // True where both operands have the same sign; there the plain quotient %5 is used.
  %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
  %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  // Opposite-sign branch: -(|arg0| + (|arg1| - 1)) / |arg1|.
  %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
  %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %8 = mhlo.constant dense<1> : tensor<2x3xi32>
  %9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
  %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %13 = mhlo.divide %11, %12 : tensor<2x3xi32>
  // Per element, choose the same-sign quotient or the adjusted quotient.
  %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  func.return %14 : tensor<2x3xi32>
}

// Float floor division: floor(x / x) is expected to become tf.FloorDiv.
// The input deliberately keeps an unused divide (%unused_quotient). The
// expected output lists two tf.Div ops ahead of the tf.FloorDiv, so do not
// drop it.
// CHECK-LABEL:   func @floordiv_f32(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_3]] : tensor<2xf32>
// CHECK:         }
func.func @floordiv_f32(%x: tensor<2xf32>) -> tensor<2xf32> {
  %unused_quotient = mhlo.divide %x, %x : tensor<2xf32>
  %quotient = mhlo.divide %x, %x : tensor<2xf32>
  %floored = "mhlo.floor"(%quotient) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %floored : tensor<2xf32>
}

// Broadcasting f16 variant of @floordiv_f32: floor(broadcast_divide(x, y))
// is expected to become tf.FloorDiv on the unbroadcast operands. The unused
// first divide is intentional; the expected output lists two tf.Div ops.
// CHECK-LABEL:   func @floordiv_f16_broadcast(
// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<2x3xf16>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<3xf16>) -> tensor<2x3xf16> {
// CHECK:           %[[VAL_2:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
// CHECK:           %[[VAL_3:.*]] = "tf.Div"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
// CHECK:           %[[VAL_4:.*]] = "tf.FloorDiv"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
// CHECK:           return %[[VAL_4]] : tensor<2x3xf16>
// CHECK:         }
func.func @floordiv_f16_broadcast(%dividend: tensor<2x3xf16>, %divisor: tensor<3xf16>) -> tensor<2x3xf16> {
  %unused_quotient = "chlo.broadcast_divide"(%dividend, %divisor) {broadcast_dimensions = array<i64: 1>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  %quotient = "chlo.broadcast_divide"(%dividend, %divisor) {broadcast_dimensions = array<i64: 1>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  %floored = "mhlo.floor"(%quotient) : (tensor<2x3xf16>) -> tensor<2x3xf16>
  func.return %floored : tensor<2x3xf16>
}

// Same-shape EQ compare becomes tf.Equal with incompatible_shape_error = true.
// CHECK-LABEL:   func @equal(
// CHECK-SAME:                %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @equal(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %is_equal = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %is_equal : tensor<2xi1>
}

// Dynamic-length lhs compared against a 1-element rhs via chlo; expected to
// become a single tf.Equal.
// CHECK-LABEL:   func @equal_dynamic(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xi32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
// CHECK:           return %[[VAL_2]] : tensor<?xi1>
// CHECK:         }
func.func @equal_dynamic(%lhs: tensor<?xi32>, %rhs: tensor<1xi32>) -> tensor<?xi1> {
  %is_equal = "chlo.broadcast_compare"(%lhs, %rhs) {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
  func.return %is_equal : tensor<?xi1>
}

// An explicit mhlo.broadcast_in_dim feeding an EQ compare is expected to be
// absorbed: tf.Equal takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @equal_broadcast(
// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                          %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @equal_broadcast(%lhs: tensor<1x1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %is_equal = "mhlo.compare"(%widened, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_equal : tensor<1x2xi1>
}

// Rank-broadcasting chlo EQ compare (tensor<1xi32> vs tensor<1x2xi32>)
// becomes tf.Equal on the unbroadcast operands.
// CHECK-LABEL:   func @equal_broadcast_chlo(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                               %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @equal_broadcast_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_equal = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_equal : tensor<1x2xi1>
}

// chlo EQ compare of tensor<2xi32> against tensor<1x2xi32>.
// NOTE(review): despite the test name, the expected tf.Equal still carries
// incompatible_shape_error = true -- confirm the name reflects the intent.
// CHECK-LABEL:   func @equal_broadcast_no_incompatible_shapes_error(
// CHECK-SAME:                                                       %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                                                       %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @equal_broadcast_no_incompatible_shapes_error(%lhs: tensor<2xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_equal = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_equal : tensor<1x2xi1>
}

// Dynamic lhs against a 1-element rhs. NOTE(review): the input and expected
// output are identical to @equal_dynamic; consider whether both are needed.
// CHECK-LABEL:   func @equal_incompatible_shape_broadcastable(
// CHECK-SAME:                                                 %[[VAL_0:.*]]: tensor<?xi32>,
// CHECK-SAME:                                                 %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
// CHECK:           return %[[VAL_2]] : tensor<?xi1>
// CHECK:         }
func.func @equal_incompatible_shape_broadcastable(%lhs: tensor<?xi32>, %rhs: tensor<1xi32>) -> tensor<?xi1> {
  %is_equal = "chlo.broadcast_compare"(%lhs, %rhs) {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
  func.return %is_equal : tensor<?xi1>
}

// A chlo compare with compare_type TOTALORDER must be left untouched by the
// legalization.
// CHECK-LABEL: func @equal_unsupported_compare_type
func.func @equal_unsupported_compare_type(%lhs: tensor<1xf32>, %rhs: tensor<1x2xf32>) -> tensor<1x2xi1> {
  // CHECK: chlo.broadcast_compare
  %total_order_eq = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, compare_type = #chlo<comparison_type TOTALORDER>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1>
  func.return %total_order_eq : tensor<1x2xi1>
}

// Same-shape NE compare becomes tf.NotEqual with incompatible_shape_error = true.
// CHECK-LABEL:   func @notequal(
// CHECK-SAME:                   %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                   %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @notequal(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %differs = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %differs : tensor<2xi1>
}

// An explicit mhlo.broadcast_in_dim feeding an NE compare is expected to be
// absorbed: tf.NotEqual takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @notequal_broadcast(
// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @notequal_broadcast(%arg0: tensor<1x1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %1 = "mhlo.compare"(%0, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %1 : tensor<1x2xi1>
}

// Rank-broadcasting chlo NE compare (tensor<1xi32> vs tensor<1x2xi32>)
// becomes tf.NotEqual on the unbroadcast operands.
// CHECK-LABEL:   func @notequal_broadcast_chlo(
// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @notequal_broadcast_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %differs = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %differs : tensor<1x2xi1>
}

// chlo NE compare of tensor<2xi32> against tensor<1x2xi32>.
// NOTE(review): despite the test name, the expected tf.NotEqual still carries
// incompatible_shape_error = true -- confirm the name reflects the intent.
// CHECK-LABEL:   func @notequal_broadcast_no_incompatible_shapes_error(
// CHECK-SAME:                                                          %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                                                          %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @notequal_broadcast_no_incompatible_shapes_error(%lhs: tensor<2xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %differs = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction NE>} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %differs : tensor<1x2xi1>
}

// Dynamic-length lhs compared (NE) against a 1-element rhs via chlo; expected
// to become a single tf.NotEqual.
// CHECK-LABEL:   func @notequal_incompatible_shape_broadcastable(
// CHECK-SAME:                                                    %[[VAL_0:.*]]: tensor<?xi32>,
// CHECK-SAME:                                                    %[[VAL_1:.*]]: tensor<1xi32>) -> tensor<?xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
// CHECK:           return %[[VAL_2]] : tensor<?xi1>
// CHECK:         }
func.func @notequal_incompatible_shape_broadcastable(%lhs: tensor<?xi32>, %rhs: tensor<1xi32>) -> tensor<?xi1> {
  %differs = "chlo.broadcast_compare"(%lhs, %rhs) {comparison_direction = #chlo<comparison_direction NE>} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
  func.return %differs : tensor<?xi1>
}

// Same-shape GT compare becomes tf.Greater (no extra attributes).
// CHECK-LABEL:   func @greater(
// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @greater(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %is_greater = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %is_greater : tensor<2xi1>
}

// An explicit mhlo.broadcast_in_dim feeding a GT compare is expected to be
// absorbed: tf.Greater takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @broadcast_greater(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_greater(%lhs: tensor<1x1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %is_greater = "mhlo.compare"(%widened, %rhs) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_greater : tensor<1x2xi1>
}

// Rank-broadcasting chlo GT compare becomes tf.Greater on unbroadcast operands.
// CHECK-LABEL:   func @broadcast_greater_chlo(
// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_greater_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_greater = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_greater : tensor<1x2xi1>
}

// An mhlo compare with compare_type TOTALORDER must be left untouched by the
// legalization.
// CHECK-LABEL: func @greater_unsupported_compare_type
func.func @greater_unsupported_compare_type(%lhs: tensor<2xf32>, %rhs: tensor<2xf32>) -> tensor<2xi1> {
  // CHECK: mhlo.compare
  %total_order_gt = "mhlo.compare"(%lhs, %rhs) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
  func.return %total_order_gt : tensor<2xi1>
}

// Same-shape GE compare becomes tf.GreaterEqual.
// CHECK-LABEL:   func @greater_equal(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @greater_equal(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %is_ge = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %is_ge : tensor<2xi1>
}

// An explicit mhlo.broadcast_in_dim feeding a GE compare is expected to be
// absorbed: tf.GreaterEqual takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @broadcast_greater_equal(
// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_greater_equal(%lhs: tensor<1x1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %is_ge = "mhlo.compare"(%widened, %rhs) {comparison_direction = #mhlo<comparison_direction GE>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_ge : tensor<1x2xi1>
}

// Rank-broadcasting chlo GE compare becomes tf.GreaterEqual on unbroadcast
// operands.
// CHECK-LABEL:   func @broadcast_greater_equal_chlo(
// CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_greater_equal_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_ge = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction GE>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_ge : tensor<1x2xi1>
}

// Same-shape LT compare becomes tf.Less.
// CHECK-LABEL:   func @less(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:               %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @less(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %is_less = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %is_less : tensor<2xi1>
}

// An explicit mhlo.broadcast_in_dim feeding an LT compare is expected to be
// absorbed: tf.Less takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @broadcast_less(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_less(%lhs: tensor<1x1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %is_less = "mhlo.compare"(%widened, %rhs) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_less : tensor<1x2xi1>
}

// Rank-broadcasting chlo LT compare becomes tf.Less on unbroadcast operands.
// CHECK-LABEL:   func @broadcast_less_chlo(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_less_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_less = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction LT>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_less : tensor<1x2xi1>
}

// Same-shape LE compare becomes tf.LessEqual.
// CHECK-LABEL:   func @less_equal(
// CHECK-SAME:                     %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME:                     %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_2]] : tensor<2xi1>
// CHECK:         }
func.func @less_equal(%lhs: tensor<2xi32>, %rhs: tensor<2xi32>) -> tensor<2xi1> {
  %is_le = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  func.return %is_le : tensor<2xi1>
}

// An explicit mhlo.broadcast_in_dim feeding an LE compare is expected to be
// absorbed: tf.LessEqual takes the original tensor<1x1xi32> argument.
// CHECK-LABEL:   func @broadcast_less_equal(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                               %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_less_equal(%lhs: tensor<1x1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %widened = "mhlo.broadcast_in_dim"(%lhs) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x2xi32>
  %is_le = "mhlo.compare"(%widened, %rhs) {comparison_direction = #mhlo<comparison_direction LE>} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_le : tensor<1x2xi1>
}

// Rank-broadcasting chlo LE compare becomes tf.LessEqual on unbroadcast
// operands.
// CHECK-LABEL:   func @broadcast_less_equal_chlo(
// CHECK-SAME:                                    %[[VAL_0:.*]]: tensor<1xi32>,
// CHECK-SAME:                                    %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK:           %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
// CHECK:           return %[[VAL_2]] : tensor<1x2xi1>
// CHECK:         }
func.func @broadcast_less_equal_chlo(%lhs: tensor<1xi32>, %rhs: tensor<1x2xi32>) -> tensor<1x2xi1> {
  %is_le = "chlo.broadcast_compare"(%lhs, %rhs) {broadcast_dimensions = array<i64: 1>, comparison_direction = #chlo<comparison_direction LE>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
  func.return %is_le : tensor<1x2xi1>
}

// mhlo.concatenate along dimension 0 becomes tf.ConcatV2, with the axis
// materialized as a scalar i64 tf.Const passed as the last operand.
// CHECK-LABEL:   func @concat_v2(
// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<3x3xf32>,
// CHECK-SAME:                    %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<6x3xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
// CHECK:           return %[[VAL_3]] : tensor<6x3xf32>
// CHECK:         }
func.func @concat_v2(%first: tensor<3x3xf32>, %second: tensor<3x3xf32>) -> tensor<6x3xf32> {
  %stacked = "mhlo.concatenate"(%first, %second) <{dimension = 0 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  func.return %stacked : tensor<6x3xf32>
}

// Same as @concat_v2 but along dimension 1, so the axis constant is 1.
// CHECK-LABEL:   func @concat_v2_1d_axis(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<3x3xf32>,
// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<3x3xf32>) -> tensor<3x6xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[VAL_3:.*]] = "tf.ConcatV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
// CHECK:           return %[[VAL_3]] : tensor<3x6xf32>
// CHECK:         }
func.func @concat_v2_1d_axis(%first: tensor<3x3xf32>, %second: tensor<3x3xf32>) -> tensor<3x6xf32> {
  %side_by_side = "mhlo.concatenate"(%first, %second) <{dimension = 1 : i64}> : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
  func.return %side_by_side : tensor<3x6xf32>
}

// mhlo.constant becomes a tf.Const carrying the same dense value.
// CHECK-LABEL:   func @const() -> tensor<2xi32> {
// CHECK:           %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:           return %[[VAL_0]] : tensor<2xi32>
// CHECK:         }
func.func @const() -> tensor<2xi32> {
  %zeros = mhlo.constant dense<0> : tensor<2xi32>
  func.return %zeros : tensor<2xi32>
}

// max(0, x) via chlo.broadcast_maximum with a scalar zero. The expected
// tf.Maximum lists the tensor argument first and the scalar constant second,
// i.e. the reverse of the input operand order.
// CHECK-LABEL:   func @relu(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           return %[[VAL_2]] : tensor<1xi32>
// CHECK:         }
func.func @relu(%input: tensor<1xi32>) -> tensor<1xi32> {
  %zero = mhlo.constant dense<0> : tensor<i32>
  %rectified = "chlo.broadcast_maximum"(%zero, %input) {broadcast_dimensions = array<i64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
  func.return %rectified : tensor<1xi32>
}

// Dynamic-length variant of @relu. NOTE(review): despite the name, the
// argument is a ranked tensor with a dynamic dimension (tensor<?xi32>), not
// an unranked tensor<*xi32>.
// CHECK-LABEL:   func @relu_unranked(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xi32>) -> tensor<?xi32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK:           return %[[VAL_2]] : tensor<?xi32>
// CHECK:         }
func.func @relu_unranked(%input: tensor<?xi32>) -> tensor<?xi32> {
  %zero = mhlo.constant dense<0> : tensor<i32>
  %rectified = "chlo.broadcast_maximum"(%zero, %input) {broadcast_dimensions = array<i64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  func.return %rectified : tensor<?xi32>
}

// Verifies a clamp-to-[0, 6] pattern on a 1xi32 tensor.
// `chlo.broadcast_minimum` with 6 legalizes to `tf.Minimum`, and the following
// `chlo.broadcast_maximum` with 0 legalizes to `tf.Maximum`. Both use scalar
// `tf.Const` operands, and the function returns the `tf.Maximum` result.
// The two constants are matched as a DAG, so their relative order is not pinned.
// CHECK-LABEL:   func @relu6(
// CHECK-SAME:                %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           return %[[VAL_4]] : tensor<1xi32>
// CHECK:         }
func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<6> : tensor<i32>
  %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array<i64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array<i64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  func.return %3 : tensor<1xi32>
}

// Same clamp-to-[0, 6] pattern as @relu6, on a 1-D tensor with a dynamic
// dimension: `tf.Minimum` with 6, then `tf.Maximum` with 0.
// Review note: despite the name, the operand type is ranked (tensor<?xi32>),
// not an unranked tensor (tensor<*xi32>).
// CHECK-LABEL:   func @relu6_unranked(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<?xi32>) -> tensor<?xi32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_3:.*]] = "tf.Minimum"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK:           %[[VAL_4:.*]] = "tf.Maximum"(%[[VAL_3]], %[[VAL_1]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK:           return %[[VAL_4]] : tensor<?xi32>
// CHECK:         }
func.func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<6> : tensor<i32>
  %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = array<i64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = array<i64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  func.return %3 : tensor<?xi32>
}

// Verifies a ReLU-gradient pattern, select(x > 0, grad, 0).
// `chlo.broadcast_compare` GT against a scalar zero legalizes to `tf.Greater`.
// `mhlo.select` with a zero-filled 4x8 constant legalizes to `tf.Select`.
// The predicate is dynamically shaped (?x?xi1) while the branches are static
// 4x8, and that mix of types is carried through unchanged.
// CHECK-LABEL:   func @relu_grad(
// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<4x8xf32>,
// CHECK-SAME:                    %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<4x8xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK:           %[[VAL_3:.*]] = "tf.Greater"(%[[VAL_1]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<4x8xf32>}> : () -> tensor<4x8xf32>
// CHECK:           %[[VAL_5:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_0]], %[[VAL_4]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK:           return %[[VAL_5]] : tensor<4x8xf32>
// CHECK:         }
func.func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = array<i64>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
  %3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  func.return %3 : tensor<4x8xf32>
}

// Verifies that `mhlo.select` with an elementwise 2xi1 predicate and matching
// 2xi32 branches legalizes directly to `tf.Select`.
// CHECK-LABEL:   func @select(
// CHECK-SAME:                 %[[VAL_0:.*]]: tensor<2xi1>,
// CHECK-SAME:                 %[[VAL_1:.*]]: tensor<2xi32>,
// CHECK-SAME:                 %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[VAL_3]] : tensor<2xi32>
// CHECK:         }
func.func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// Same as @select but with f32 branches, confirming that the element type
// of the selected values does not affect the `tf.Select` legalization.
// CHECK-LABEL:   func @select_float(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<2xi1>,
// CHECK-SAME:                       %[[VAL_1:.*]]: tensor<2xf32>,
// CHECK-SAME:                       %[[VAL_2:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_3]] : tensor<2xf32>
// CHECK:         }
func.func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Same as @select on rank-2 (3x2) operands with a same-shaped predicate;
// still legalizes to a single `tf.Select`.
// CHECK-LABEL:   func @select_multidimensional(
// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<3x2xi1>,
// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<3x2xi32>,
// CHECK-SAME:                                  %[[VAL_2:.*]]: tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK:           return %[[VAL_3]] : tensor<3x2xi32>
// CHECK:         }
func.func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  func.return %0 : tensor<3x2xi32>
}

// Verifies that a same-shape `mhlo.select` legalizes to `tf.Select`.
// Review note: this input is identical to @select and also produces
// `tf.Select` (not `tf.SelectV2`), so it adds no coverage beyond @select.
// CHECK-LABEL:   func @selectv2(
// CHECK-SAME:                   %[[VAL_0:.*]]: tensor<2xi1>,
// CHECK-SAME:                   %[[VAL_1:.*]]: tensor<2xi32>,
// CHECK-SAME:                   %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[VAL_3]] : tensor<2xi32>
// CHECK:         }
func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// Verifies that `mhlo.select` with a scalar (rank-0) i1 predicate and 2xi32
// branches still legalizes to `tf.Select`, keeping the scalar predicate as-is.
// CHECK-LABEL:   func @selectv2_pred_scalar(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<i1>,
// CHECK-SAME:                               %[[VAL_1:.*]]: tensor<2xi32>,
// CHECK-SAME:                               %[[VAL_2:.*]]: tensor<2xi32>) -> tensor<2xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Select"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK:           return %[[VAL_3]] : tensor<2xi32>
// CHECK:         }
func.func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// Verifies that an `mhlo.broadcast_in_dim` (1x1 -> 1x100) feeding a select
// branch legalizes into a `tf.SelectV2`. The `tf.SelectV2` consumes the
// original, unbroadcast 1x1 argument directly and relies on its implicit
// broadcasting.
// CHECK-LABEL:   func @selectv2_broadcasted_operand(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<i1>,
// CHECK-SAME:                               %[[VAL_1:.*]]: tensor<1x1xi32>,
// CHECK-SAME:                               %[[VAL_2:.*]]: tensor<1x100xi32>) -> tensor<1x100xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.SelectV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i1>, tensor<1x1xi32>, tensor<1x100xi32>) -> tensor<1x100xi32>
// CHECK:           return %[[VAL_3]] : tensor<1x100xi32>
// CHECK:         }
func.func @selectv2_broadcasted_operand(%arg0: tensor<i1>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x100xi32>) -> tensor<1x100xi32> {
  %0 = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi32>) -> tensor<1x100xi32>
  %1 = "mhlo.select"(%arg0, %0, %arg2) : (tensor<i1>, tensor<1x100xi32>, tensor<1x100xi32>) -> tensor<1x100xi32>
  func.return %1 : tensor<1x100xi32>
}

// Same as @selectv2_broadcasted_operand, but the broadcast (1x1 -> 1x100) is
// applied to the i1 predicate. The resulting `tf.SelectV2` takes the original
// 1x1 predicate directly.
// CHECK-LABEL:   func @selectv2_broadcasted_condition(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<1x1xi1>,
// CHECK-SAME:                               %[[VAL_1:.*]]: tensor<1x100xi32>,
// CHECK-SAME:                               %[[VAL_2:.*]]: tensor<1x100xi32>) -> tensor<1x100xi32> {
// CHECK:           %[[VAL_3:.*]] = "tf.SelectV2"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<1x1xi1>, tensor<1x100xi32>, tensor<1x100xi32>) -> tensor<1x100xi32>
// CHECK:           return %[[VAL_3]] : tensor<1x100xi32>
// CHECK:         }
func.func @selectv2_broadcasted_condition(%arg0: tensor<1x1xi1>, %arg1: tensor<1x100xi32>, %arg2: tensor<1x100xi32>) -> tensor<1x100xi32> {
  %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xi1>) -> tensor<1x100xi1>
  %1 = "mhlo.select"(%0, %arg1, %arg2) : (tensor<1x100xi1>, tensor<1x100xi32>, tensor<1x100xi32>) -> tensor<1x100xi32>
  func.return %1 : tensor<1x100xi32>
}

// Verifies that a static 2-D `mhlo.transpose` with permutation [1, 0]
// legalizes to `tf.Transpose`, whose permutation operand is a newly
// materialized i64 `tf.Const`. The input has no other constants, so exactly
// one `tf.Const` is matched.
// CHECK-LABEL:   func @transpose_2d(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<3x2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_2:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
// CHECK:           return %[[VAL_2]] : tensor<3x2xf32>
// CHECK:         }
func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<3x2xf32>
  func.return %0 : tensor<3x2xf32>
}

// Verifies that a 3-D `mhlo.transpose` with permutation [2, 1, 0] legalizes to
// `tf.Transpose`, whose permutation operand is an i64 `tf.Const` (bound to VAL_3).
// Review note: despite the name, the transpose permutation is i64. The unused
// constants %0 (i32) and %1 (i64) only exercise `mhlo.constant` legalization
// and account for the two extra `tf.Const` ops matched as a DAG.
// CHECK-LABEL:   func @transpose_3d_int32(
// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_3]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
// CHECK:           return %[[VAL_4]] : tensor<3x2x1xf32>
// CHECK:         }
func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32>
  %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  func.return %2 : tensor<3x2x1xf32>
}

// Verifies that a static 3-D `mhlo.transpose` with permutation [2, 1, 0]
// legalizes to `tf.Transpose`, whose permutation operand is a newly
// materialized i64 `tf.Const`. The input has no other constants, so exactly
// one `tf.Const` is matched.
// CHECK-LABEL:   func @transpose_3d(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_2:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
// CHECK:           return %[[VAL_2]] : tensor<3x2x1xf32>
// CHECK:         }
func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[2, 1, 0]> : tensor<3xi64>}> : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  func.return %0 : tensor<3x2x1xf32>
}

// Verifies that a 2-D `mhlo.transpose` with a dynamic leading dimension
// (?x4 -> 4x?) legalizes to `tf.Transpose`, whose permutation operand is a
// newly materialized i64 `tf.Const`. The input has no other constants, so
// exactly one `tf.Const` is matched.
// CHECK-LABEL:   func @transpose_dynamic_2d(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<?x4xf32>) -> tensor<4x?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_2:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32>
// CHECK:           return %[[VAL_2]] : tensor<4x?xf32>
// CHECK:         }
func.func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
  %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0]> : tensor<2xi64>}> : (tensor<?x4xf32>) -> tensor<4x?xf32>
  func.return %0 : tensor<4x?xf32>
}

// Verifies that `mhlo.abs` on a static real (2xf32) tensor legalizes to `tf.Abs`.
// CHECK-LABEL:   func @abs(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.abs` on a dynamically sized 1-D tensor (?xf32) legalizes
// to `tf.Abs` with the dynamic type preserved.
// CHECK-LABEL:   func @abs_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Abs"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.ceil` on a static 2xf32 tensor legalizes to `tf.Ceil`.
// CHECK-LABEL:   func @ceil(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.ceil` on a dynamically sized 1-D tensor legalizes to
// `tf.Ceil` with the dynamic type preserved.
// CHECK-LABEL:   func @ceil_dynamic(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Ceil"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.abs` on a complex<f32> tensor legalizes to
// `tf.ComplexAbs` rather than `tf.Abs`. The operation takes a complex input
// and yields a real f32 result.
// CHECK-LABEL:   func @complex_abs(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.ComplexAbs"(%[[VAL_0]]) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.cosine` on a static 2xf32 tensor legalizes to `tf.Cos`.
// CHECK-LABEL:   func @cos(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.cosine` on a dynamically sized 1-D tensor legalizes to
// `tf.Cos` with the dynamic type preserved.
// CHECK-LABEL:   func @cos_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Cos"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.exponential` on a static 2xf32 tensor legalizes to `tf.Exp`.
// CHECK-LABEL:   func @exp(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.exponential` on a dynamically sized 1-D tensor legalizes
// to `tf.Exp` with the dynamic type preserved.
// CHECK-LABEL:   func @exp_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Exp"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.floor` on a static 2xf32 tensor legalizes to `tf.Floor`.
// CHECK-LABEL:   func @floor(
// CHECK-SAME:                %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.floor` on a dynamically sized 1-D tensor legalizes to
// `tf.Floor` with the dynamic type preserved.
// CHECK-LABEL:   func @floor_dynamic(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Floor"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.is_finite` on a static 2xf32 tensor legalizes to
// `tf.IsFinite`, producing a same-shaped i1 mask.
// CHECK-LABEL:   func @is_finite(
// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi1> {
// CHECK:           %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi1>
// CHECK:           return %[[VAL_1]] : tensor<2xi1>
// CHECK:         }
func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
  %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  func.return %0 : tensor<2xi1>
}

// Verifies that `mhlo.is_finite` on a dynamically sized 1-D tensor legalizes
// to `tf.IsFinite`, producing a ?xi1 mask.
// CHECK-LABEL:   func @is_finite_dynamic(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xi1> {
// CHECK:           %[[VAL_1:.*]] = "tf.IsFinite"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xi1>
// CHECK:           return %[[VAL_1]] : tensor<?xi1>
// CHECK:         }
func.func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
  %0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  func.return %0 : tensor<?xi1>
}

// Verifies that `mhlo.log` on a static 2xf32 tensor legalizes to `tf.Log`.
// CHECK-LABEL:   func @log(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.log` on a dynamically sized 1-D tensor legalizes to
// `tf.Log` with the dynamic type preserved.
// CHECK-LABEL:   func @log_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Log"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.log_plus_one` on a static 2xf32 tensor legalizes to `tf.Log1p`.
// CHECK-LABEL:   func @log1p(
// CHECK-SAME:                %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.log_plus_one` on a dynamically sized 1-D tensor
// legalizes to `tf.Log1p` with the dynamic type preserved.
// CHECK-LABEL:   func @log1p_dynamic(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Log1p"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.negate` on a static 2xf32 tensor legalizes to `tf.Neg`.
// CHECK-LABEL:   func @neg(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.negate` on a dynamically sized 1-D tensor legalizes to
// `tf.Neg` with the dynamic type preserved.
// CHECK-LABEL:   func @neg_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Neg"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that the tanh-based sigmoid expansion 0.5 * tanh(0.5 * x) + 0.5
// legalizes op-by-op to `tf.Mul`, `tf.Tanh`, `tf.Mul`, `tf.AddV2`. The
// returned value is the `tf.AddV2` result; no single-op sigmoid is formed.
// Review note: %0 (scalar 0.5) and %1 (dense<2> : tensor<1xi64>) are unused.
// They still legalize to `tf.Const` and are matched by the first two DAG
// assertions below.
// CHECK-LABEL:   func @sigmoid(
// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.Mul"(%[[VAL_0]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           %[[VAL_5:.*]] = "tf.Tanh"(%[[VAL_4]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.Mul"(%[[VAL_5]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           %[[VAL_7:.*]] = "tf.AddV2"(%[[VAL_6]], %[[VAL_3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_7]] : tensor<2xf32>
// CHECK:         }
func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = mhlo.constant dense<5.000000e-01> : tensor<f32>
  %1 = mhlo.constant dense<2> : tensor<1xi64>
  %2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
  %3 = mhlo.multiply %arg0, %2 : tensor<2xf32>
  %4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
  %5 = mhlo.multiply %4, %2 : tensor<2xf32>
  %6 = mhlo.add %5, %2 : tensor<2xf32>
  func.return %6 : tensor<2xf32>
}

// Verifies that `mhlo.sine` on a static 2xf32 tensor legalizes to `tf.Sin`.
// CHECK-LABEL:   func @sin(
// CHECK-SAME:              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.sine` on a dynamically sized 1-D tensor legalizes to
// `tf.Sin` with the dynamic type preserved.
// CHECK-LABEL:   func @sin_dynamic(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Sin"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.rsqrt` on a static 2xf32 tensor legalizes to `tf.Rsqrt`.
// CHECK-LABEL:   func @rsqrt(
// CHECK-SAME:                %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.rsqrt` on a dynamically sized 1-D tensor legalizes to
// `tf.Rsqrt` with the dynamic type preserved.
// CHECK-LABEL:   func @rsqrt_dynamic(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Rsqrt"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.sqrt` on a static 2xf32 tensor legalizes to `tf.Sqrt`.
// CHECK-LABEL:   func @sqrt(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.sqrt` on a dynamically sized 1-D tensor legalizes to
// `tf.Sqrt` with the dynamic type preserved.
// CHECK-LABEL:   func @sqrt_dynamic(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Sqrt"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.tanh` on a static 2xf32 tensor legalizes to `tf.Tanh`.
// CHECK-LABEL:   func @tanh(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that `mhlo.tanh` on a dynamically sized 1-D tensor legalizes to
// `tf.Tanh` with the dynamic type preserved.
// CHECK-LABEL:   func @tanh_dynamic(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Tanh"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that an identity-typed `mhlo.bitcast_convert` (2xf32 -> 2xf32)
// legalizes to `tf.Bitcast`. The returned value is the `tf.Bitcast` result;
// the op is not folded away.
// CHECK-LABEL:   func @bitcast(
// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:           return %[[VAL_1]] : tensor<2xf32>
// CHECK:         }
func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies that an identity-typed `mhlo.bitcast_convert` on a dynamically
// sized 1-D tensor legalizes to `tf.Bitcast` with the dynamic type preserved.
// CHECK-LABEL:   func @bitcast_dynamic(
// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK:           return %[[VAL_1]] : tensor<?xf32>
// CHECK:         }
func.func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// Verifies that `mhlo.bitcast_convert` between distinct element types of the
// same bit width (f32 -> i32), with the shape unchanged, legalizes to
// `tf.Bitcast`.
// CHECK-LABEL:   func @bitcast_same_widths(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xi32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Bitcast"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xi32>
// CHECK:           return %[[VAL_1]] : tensor<2xi32>
// CHECK:         }
func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
  %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  func.return %0 : tensor<2xi32>
}

// Exercises legalization of `mhlo.compare` (NE), `mhlo.sign` and nested
// `mhlo.select` on 1x2x3x4xf32 tensors:
// - each compare becomes `tf.NotEqual` with incompatible_shape_error = true;
// - `mhlo.sign` becomes `tf.Sign`;
// - each select becomes `tf.Select`.
// The legalized ops are matched in the same order as the input ops.
// Review note: both compares are identical (%arg0 != %arg1) and the two zero
// constants are duplicates. This appears to be a mechanical legalization test
// rather than a meaningful sign computation. The duplicates are kept because
// the assertions below pin each of them.
// CHECK-LABEL:   func @sign(
// CHECK-SAME:               %[[VAL_0:.*]]: tensor<1x2x3x4xf32>,
// CHECK-SAME:               %[[VAL_1:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1x2x3x4xf32>}> : () -> tensor<1x2x3x4xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) <{incompatible_shape_error = true}> : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1x2x3x4xf32>}> : () -> tensor<1x2x3x4xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK:           %[[VAL_7:.*]] = "tf.Select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK:           %[[VAL_8:.*]] = "tf.Select"(%[[VAL_2]], %[[VAL_3]], %[[VAL_7]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK:           return %[[VAL_8]] : tensor<1x2x3x4xf32>
// CHECK:         }
func.func @sign(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
  %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
  %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
  %3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
  %4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
  %5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
  %6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
  func.return %6 : tensor<1x2x3x4xf32>
}

// Returns the i32 constant 1, so only `mhlo.constant` -> `tf.Const`
// legalization is exercised. %arg0 (a rank-0 f32 tensor) is unused.
// Presumably this mirrors an already-folded size of the scalar argument;
// confirm against the test's origin.
// CHECK-LABEL:   func @size_rank_one_i32(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<f32>) -> tensor<i32> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           return %[[VAL_1]] : tensor<i32>
// CHECK:         }
func.func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
  %0 = mhlo.constant dense<1> : tensor<i32>
  func.return %0 : tensor<i32>
}

// i64 variant of @size_rank_one_i32. It returns the constant 1 as i64, and
// only constant legalization is exercised. %arg0 is unused.
// CHECK-LABEL:   func @size_rank_one_i64(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<f32>) -> tensor<i64> {
// CHECK:           %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           return %[[VAL_1]] : tensor<i64>
// CHECK:         }
func.func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
  %0 = mhlo.constant dense<1> : tensor<i64>
  func.return %0 : tensor<i64>
}

// Verifies that `mhlo.complex`, which combines two 3xf32 tensors into a
// 3xcomplex<f32> tensor, legalizes to `tf.Complex` with the operand order kept.
// CHECK-LABEL:   func @complex(
// CHECK-SAME:                  %[[VAL_0:.*]]: tensor<3xf32>,
// CHECK-SAME:                  %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
// CHECK:           %[[VAL_2:.*]] = "tf.Complex"(%[[VAL_0]], %[[VAL_1]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
// CHECK:           return %[[VAL_2]] : tensor<3xcomplex<f32>>
// CHECK:         }
func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
  func.return %0 : tensor<3xcomplex<f32>>
}

// Verifies that an i32 -> f32 mhlo.convert becomes tf.Cast with
// Truncate = false.
// CHECK-LABEL:   func @convert_i32_f32(
// CHECK-SAME:                          %[[INPUT:.*]]: tensor<2xi32>) -> tensor<2xf32> {
// CHECK:           %[[CAST:.*]] = "tf.Cast"(%[[INPUT]]) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xf32>
// CHECK:           return %[[CAST]] : tensor<2xf32>
// CHECK:         }
func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  %cast = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  func.return %cast : tensor<2xf32>
}

// Verifies that a static mhlo.slice becomes tf.StridedSlice. Its begin, end
// and strides constants come from the slice's start_indices, limit_indices
// and strides, and every mask attribute is zero.
// CHECK-LABEL:   func @convert_slice(
// CHECK-SAME:                        %[[INPUT:.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> {
// CHECK-DAG:       %[[BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 4153]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK-DAG:       %[[END:.*]] = "tf.Const"() <{value = dense<[1, 4672]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK-DAG:       %[[STRIDES:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[SLICE:.*]] = "tf.StridedSlice"(%[[INPUT]], %[[BEGIN]], %[[END]], %[[STRIDES]])
// CHECK-SAME:          <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}>
// CHECK-SAME:          (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32>
// CHECK:           return %[[SLICE]] : tensor<1x519xf32>
// CHECK:         }
func.func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
  %slice = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x4672xf32>) -> tensor<1x519xf32>
  func.return %slice : tensor<1x519xf32>
}

// Verifies that a static mhlo.reshape becomes tf.Reshape, with the target
// shape materialized as an arith.constant.
// CHECK-LABEL:   func @reshape(
// CHECK-SAME:                  %[[INPUT:.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
// CHECK:           %[[SHAPE:.*]] = arith.constant dense<[2, 2, 6]> : tensor<3xi64>
// CHECK:           %[[RESHAPED:.*]] = "tf.Reshape"(%[[INPUT]], %[[SHAPE]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32>
// CHECK:           return %[[RESHAPED]] : tensor<2x2x6xf32>
// CHECK:         }
func.func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
  %reshaped = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
  func.return %reshaped : tensor<2x2x6xf32>
}

// Verifies that chained mhlo.dynamic_reshape ops each become a tf.Reshape that
// consumes the runtime shape operand directly.
// CHECK-LABEL:   func @dynamic_reshape(
// CHECK-SAME:                         %[[INPUT:.*]]: tensor<1x1x1x?xf32>,
// CHECK-SAME:                         %[[SHAPE_3D:.*]]: tensor<3xi32>,
// CHECK-SAME:                         %[[SHAPE_1D:.*]]: tensor<1xi32>) -> tensor<?xf32> {
// CHECK:           %[[RESHAPE_3D:.*]] = "tf.Reshape"(%[[INPUT]], %[[SHAPE_3D]]) : (tensor<1x1x1x?xf32>, tensor<3xi32>) -> tensor<1x1x?xf32>
// CHECK:           %[[RESHAPE_1D:.*]] = "tf.Reshape"(%[[RESHAPE_3D]], %[[SHAPE_1D]]) : (tensor<1x1x?xf32>, tensor<1xi32>) -> tensor<?xf32>
// CHECK:           return %[[RESHAPE_1D]] : tensor<?xf32>
func.func @dynamic_reshape(%arg0: tensor<1x1x1x?xf32>, %arg1: tensor<3xi32>, %arg2: tensor<1xi32>) -> tensor<?xf32> {
  %reshape_3d = mhlo.dynamic_reshape %arg0, %arg1 : (tensor<1x1x1x?xf32>, tensor<3xi32>) -> tensor<1x1x?xf32>
  %reshape_1d = mhlo.dynamic_reshape %reshape_3d, %arg2 : (tensor<1x1x?xf32>, tensor<1xi32>) -> tensor<?xf32>
  func.return %reshape_1d : tensor<?xf32>
}

// Verifies that mhlo.round_nearest_even lowers to tf.Round. TensorFlow's Round
// op is documented to round half to even, so the mapping is one-to-one.
// CHECK-LABEL: func @round_nearest_even(
// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<2xf32> {
// CHECK:         %[[VAL_1:.*]] = "tf.Round"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK:         return %[[VAL_1]] : tensor<2xf32>
// CHECK:       }
func.func @round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "mhlo.round_nearest_even"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}

// Verifies the matrix-vector mhlo.dot lowering. The rhs vector is reshaped to
// a 256x1 column, multiplied with tf.BatchMatMulV3, and the 1x1 product is
// reshaped back to the rank-1 result.
// CHECK-LABEL:   func @convert_dot_2d_1d(
// CHECK-SAME:                            %[[LHS:.*]]: tensor<1x256xf32>,
// CHECK-SAME:                            %[[RHS:.*]]: tensor<256xf32>) -> tensor<1xf32> {
// CHECK:           %[[RHS_SHAPE:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64>
// CHECK:           %[[RHS_COL:.*]] = "tf.Reshape"(%[[RHS]], %[[RHS_SHAPE]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32>
// CHECK:           %[[MATMUL:.*]] = "tf.BatchMatMulV3"(%[[LHS]], %[[RHS_COL]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
// CHECK:           %[[OUT_SHAPE:.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL]], %[[OUT_SHAPE]]) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           return %[[RESULT]] : tensor<1xf32>
// CHECK:         }
func.func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> {
  %dot = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
  func.return %dot : tensor<1xf32>
}

// Verifies the vector-vector (inner product) mhlo.dot lowering. The lhs
// becomes a 1x256 row and the rhs a 256x1 column, tf.BatchMatMulV3 yields a
// 1x1 product, and a reshape with an empty shape produces the scalar result.
// CHECK-LABEL:   func @convert_dot_1d_1d(
// CHECK-SAME:                            %[[LHS:.*]]: tensor<256xf32>,
// CHECK-SAME:                            %[[RHS:.*]]: tensor<256xf32>) -> tensor<f32> {
// CHECK:           %[[LHS_SHAPE:.*]] = arith.constant dense<[1, 256]> : tensor<2xi64>
// CHECK:           %[[LHS_ROW:.*]] = "tf.Reshape"(%[[LHS]], %[[LHS_SHAPE]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK:           %[[RHS_SHAPE:.*]] = arith.constant dense<[256, 1]> : tensor<2xi64>
// CHECK:           %[[RHS_COL:.*]] = "tf.Reshape"(%[[RHS]], %[[RHS_SHAPE]]) : (tensor<256xf32>, tensor<2xi64>) -> tensor<256x1xf32>
// CHECK:           %[[MATMUL:.*]] = "tf.BatchMatMulV3"(%[[LHS_ROW]], %[[RHS_COL]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
// CHECK:           %[[SCALAR_SHAPE:.*]] = arith.constant dense<> : tensor<0xi64>
// CHECK:           %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL]], %[[SCALAR_SHAPE]]) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor<f32>
// CHECK:           return %[[RESULT]] : tensor<f32>
// CHECK:         }
func.func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> {
  %dot = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
  func.return %dot : tensor<f32>
}

// Verifies that a plain matrix-matrix mhlo.dot maps directly onto a single
// tf.BatchMatMulV3, with no reshapes.
// CHECK-LABEL:   func @convert_dot_2d_2d(
// CHECK-SAME:                            %[[LHS:.*]]: tensor<1x256xf32>,
// CHECK-SAME:                            %[[RHS:.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> {
// CHECK:           %[[MATMUL:.*]] = "tf.BatchMatMulV3"(%[[LHS]], %[[RHS]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
// CHECK:           return %[[MATMUL]] : tensor<1x1xf32>
// CHECK:         }
func.func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
  %dot = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
  func.return %dot : tensor<1x1xf32>
}

// broadcast_dimensions [1, 2, 3] map the rank-3 operand onto the trailing
// result dims. This test verifies that the lowering is a single tf.BroadcastTo
// to the constant result shape.
// CHECK-LABEL:   func @broadcast_in_dim_tf_style(
// CHECK-SAME:                                    %[[INPUT:.*]]: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
// CHECK:           %[[TARGET_SHAPE:.*]] = arith.constant dense<[3, 8, 8, 16]> : tensor<4xi64>
// CHECK:           %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[INPUT]], %[[TARGET_SHAPE]]) : (tensor<8x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK:           return %[[BROADCAST]] : tensor<3x8x8x16xf32>
// CHECK:         }
func.func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
  %broadcast = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"}> : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
  func.return %broadcast : tensor<3x8x8x16xf32>
}

// broadcast_dimensions [0, 2, 3] skip result dim 1. This test verifies that
// the operand is first reshaped to 3x1x1x16 (a unit dim at position 1) and
// then broadcast with tf.BroadcastTo.
// CHECK-LABEL:   func @broadcast_in_dim_general_case(
// CHECK-SAME:                                        %[[INPUT:.*]]: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
// CHECK:           %[[RESHAPE_SHAPE:.*]] = arith.constant dense<[3, 1, 1, 16]> : tensor<4xi64>
// CHECK:           %[[RESHAPED:.*]] = "tf.Reshape"(%[[INPUT]], %[[RESHAPE_SHAPE]]) : (tensor<3x1x16xf32>, tensor<4xi64>) -> tensor<3x1x1x16xf32>
// CHECK:           %[[TARGET_SHAPE:.*]] = arith.constant dense<[3, 8, 8, 16]> : tensor<4xi64>
// CHECK:           %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[RESHAPED]], %[[TARGET_SHAPE]]) : (tensor<3x1x1x16xf32>, tensor<4xi64>) -> tensor<3x8x8x16xf32>
// CHECK:           return %[[BROADCAST]] : tensor<3x8x8x16xf32>
// CHECK:         }
func.func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
  %broadcast = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"}> : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
  func.return %broadcast : tensor<3x8x8x16xf32>
}

// Operand and result have the same rank and broadcast_dimensions is the
// identity. The lowering should therefore be a single tf.BroadcastTo fed
// directly by the runtime shape operand.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_tf_style(
// CHECK-SAME:                               %[[ARG_0:.*]]: tensor<?x1x1x2x1xf32>,
// CHECK-SAME:                               %[[ARG_1:.*]]: tensor<5xi32>) -> tensor<?x750x1x2x384xf32> {
// CHECK:         %[[VAL_0:.*]] = "tf.BroadcastTo"(%[[ARG_0]], %[[ARG_1]]) : (tensor<?x1x1x2x1xf32>, tensor<5xi32>) -> tensor<?x750x1x2x384xf32>
// CHECK:         return %[[VAL_0]] : tensor<?x750x1x2x384xf32>
func.func @dynamic_broadcast_in_dim_tf_style(%arg0: tensor<?x1x1x2x1xf32>, %arg1: tensor<5xi32>) -> tensor<?x750x1x2x384xf32> {
  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : (tensor<?x1x1x2x1xf32>, tensor<5xi32>) -> tensor<?x750x1x2x384xf32>
  func.return %0 : tensor<?x750x1x2x384xf32>
}

// broadcast_dimensions [0, 1] leave result dims 2 and 3 unmapped. The operand
// should gain unit dims at positions 2 and then 3 via tf.ExpandDims before
// tf.BroadcastTo applies the runtime shape.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_back_dims(
// CHECK-SAME:                               %[[ARG_0:.*]]: tensor<?x3000xf32>,
// CHECK-SAME:                               %[[ARG_1:.*]]: tensor<4xi32>) -> tensor<?x3000x2x4xf32> {
// CHECK:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
// CHECK:         %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor<?x3000xf32>, tensor<i64>) -> tensor<?x3000x1xf32>
// CHECK:         %[[CST_1:.*]] = "tf.Const"() <{value = dense<3> : tensor<i64>}> : () -> tensor<i64>
// CHECK:         %[[VAL_1:.*]] = "tf.ExpandDims"(%[[VAL_0]], %[[CST_1]]) : (tensor<?x3000x1xf32>, tensor<i64>) -> tensor<?x3000x1x1xf32>
// CHECK:         %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_1]], %[[ARG_1]]) : (tensor<?x3000x1x1xf32>, tensor<4xi32>) -> tensor<?x3000x2x4xf32>
// CHECK:         return %[[VAL_2]] : tensor<?x3000x2x4xf32>
func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor<?x3000xf32>, %arg1: tensor<4xi32>) -> tensor<?x3000x2x4xf32> {
  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<?x3000xf32>, tensor<4xi32>) -> tensor<?x3000x2x4xf32>
  func.return %0 : tensor<?x3000x2x4xf32>
}

// broadcast_dimensions [0, 1, 3] leave result dim 2 unmapped. The operand
// should gain a unit dim at position 2 via tf.ExpandDims before tf.BroadcastTo
// applies the runtime shape.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(
// CHECK-SAME:                               %[[ARG_0:.*]]: tensor<?x750x768xf32>,
// CHECK-SAME:                               %[[ARG_1:.*]]: tensor<4xi32>) -> tensor<?x750x1x768xf32> {
// CHECK:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<2> : tensor<i64>}> : () -> tensor<i64>
// CHECK:         %[[VAL_0:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_0]]) : (tensor<?x750x768xf32>, tensor<i64>) -> tensor<?x750x1x768xf32>
// CHECK:         %[[VAL_1:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[ARG_1]]) : (tensor<?x750x1x768xf32>, tensor<4xi32>) -> tensor<?x750x1x768xf32>
// CHECK:         return %[[VAL_1]] : tensor<?x750x1x768xf32>
func.func @dynamic_broadcast_in_dim_general_case_expand_middle_dim(%arg0: tensor<?x750x768xf32>, %arg1: tensor<4xi32>) -> tensor<?x750x1x768xf32> {
  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<[0, 1, 3]> : tensor<3xi64>}> : (tensor<?x750x768xf32>, tensor<4xi32>) -> tensor<?x750x1x768xf32>
  func.return %0 : tensor<?x750x1x768xf32>
}

// Static dot_general with one batch dim and two contracting dims per operand.
// The lhs is transposed to [batch, free..., contracting...] and the rhs to
// [batch, contracting..., free...]. Both are flattened to rank 3 and
// multiplied, and the 3x5x4 product is reshaped back to 3x5x1x4.
// CHECK-LABEL:   func @convert_dot_general(
// CHECK-SAME:                              %[[LHS:.*]]: tensor<3x2x6x5x1xf32>,
// CHECK-SAME:                              %[[RHS:.*]]: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
// CHECK:           %[[LHS_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 4, 1, 2]> : tensor<5xi64>}> : () -> tensor<5xi64>
// CHECK:           %[[LHS_T:.*]] = "tf.Transpose"(%[[LHS]], %[[LHS_PERM]]) : (tensor<3x2x6x5x1xf32>, tensor<5xi64>) -> tensor<3x5x1x2x6xf32>
// CHECK:           %[[RHS_PERM:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[RHS_T:.*]] = "tf.Transpose"(%[[RHS]], %[[RHS_PERM]]) : (tensor<3x2x4x6xf32>, tensor<4xi64>) -> tensor<3x2x6x4xf32>
// CHECK:           %[[LHS_SHAPE:.*]] = arith.constant dense<[3, 5, 12]> : tensor<3xi64>
// CHECK:           %[[LHS_3D:.*]] = "tf.Reshape"(%[[LHS_T]], %[[LHS_SHAPE]]) : (tensor<3x5x1x2x6xf32>, tensor<3xi64>) -> tensor<3x5x12xf32>
// CHECK:           %[[RHS_SHAPE:.*]] = arith.constant dense<[3, 12, 4]> : tensor<3xi64>
// CHECK:           %[[RHS_3D:.*]] = "tf.Reshape"(%[[RHS_T]], %[[RHS_SHAPE]]) : (tensor<3x2x6x4xf32>, tensor<3xi64>) -> tensor<3x12x4xf32>
// CHECK:           %[[MATMUL:.*]] = "tf.BatchMatMulV3"(%[[LHS_3D]], %[[RHS_3D]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32>
// CHECK:           %[[OUT_SHAPE:.*]] = arith.constant dense<[3, 5, 1, 4]> : tensor<4xi64>
// CHECK:           %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL]], %[[OUT_SHAPE]]) : (tensor<3x5x4xf32>, tensor<4xi64>) -> tensor<3x5x1x4xf32>
// CHECK:           return %[[RESULT]] : tensor<3x5x1x4xf32>
// CHECK:         }
func.func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
  %dot = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [1, 2],
      rhs_batching_dimensions = [0],
      rhs_contracting_dimensions = [1, 3]
    >,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
  func.return %dot : tensor<3x5x1x4xf32>
}

// Verifies that a dot_general whose rhs has a quantized (!quant.uniform)
// element type is left as mhlo.dot_general rather than being lowered.
// CHECK-LABEL:   func @quantized_dot_general_not_converted
// CHECK:           "mhlo.dot_general"
func.func @quantized_dot_general_not_converted(%arg0: tensor<1x1x512xf32>, %arg1: tensor<512x512x!quant.uniform<i8:f32, 0.00285>>) -> tensor<1x1x512xf32> {
  %dot = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [0]
    >
  } : (tensor<1x1x512xf32>, tensor<512x512x!quant.uniform<i8:f32, 0.00285>>) -> tensor<1x1x512xf32>
  func.return %dot : tensor<1x1x512xf32>
}

// dot_general with no batch dims and a rank-3 lhs. The lhs is flattened to
// 1x1024, multiplied with the 2-D rhs, and the 1x1024 product is reshaped
// back to 1x1x1024.
// CHECK-LABEL:   func @convert_dot_general_repeated(
// CHECK-SAME:                                       %[[LHS:.*]]: tensor<1x1x1024xf32>,
// CHECK-SAME:                                       %[[RHS:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> {
// CHECK:           %[[LHS_SHAPE:.*]] = arith.constant dense<[1, 1024]> : tensor<2xi64>
// CHECK:           %[[LHS_2D:.*]] = "tf.Reshape"(%[[LHS]], %[[LHS_SHAPE]]) : {{.*}} -> tensor<1x1024xf32>
// CHECK:           %[[MATMUL:.*]] = "tf.BatchMatMulV3"(%[[LHS_2D]], %[[RHS]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : {{.*}} -> tensor<1x1024xf32>
// CHECK:           %[[OUT_SHAPE:.*]] = arith.constant dense<[1, 1, 1024]> : tensor<3xi64>
// CHECK:           %[[RESULT:.*]] = "tf.Reshape"(%[[MATMUL]], %[[OUT_SHAPE]]) : {{.*}} -> tensor<1x1x1024xf32>
// CHECK:           return %[[RESULT]] : tensor<1x1x1024xf32>
// CHECK:         }
func.func @convert_dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> {
  %dot = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [],
      lhs_contracting_dimensions = [2],
      rhs_batching_dimensions = [],
      rhs_contracting_dimensions = [0]
    >,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<1x1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1x1024xf32>
  func.return %dot : tensor<1x1x1024xf32>
}

// Vector-matrix dot_general on i8 operands with an i32 result. The lhs vector
// is reshaped to 1x256, tf.BatchMatMulV3 keeps the i32 result element type,
// and the 1x8 product is reshaped back to rank 1.
// CHECK-LABEL:   func @convert_dot_general_int8(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<256xi8>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<256x8xi8>) -> tensor<8xi32> {
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[1, 256]> : tensor<2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<256xi8>, tensor<2xi64>) -> tensor<1x256xi8>
// CHECK:           %[[VAL_4:.*]] = "tf.BatchMatMulV3"(%[[VAL_3]], %[[VAL_1]]) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<1x256xi8>, tensor<256x8xi8>) -> tensor<1x8xi32>
// CHECK:           %[[VAL_5:.*]] = arith.constant dense<8> : tensor<1xi64>
// CHECK:           %[[VAL_6:.*]] = "tf.Reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x8xi32>, tensor<1xi64>) -> tensor<8xi32>
// CHECK:           return %[[VAL_6]] : tensor<8xi32>
// CHECK:         }
func.func @convert_dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi8>) -> tensor<8xi32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_contracting_dimensions = [0],
      rhs_contracting_dimensions = [0]
    >,
    // precision_config is a sibling attribute of dot_dimension_numbers, not
    // part of #mhlo.dot<...>.
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
  } : (tensor<256xi8>, tensor<256x8xi8>) -> tensor<8xi32>
  func.return %0 : tensor<8xi32>
}

// The rhs has a dynamic free (non-contracting) dim. The rhs is transposed to
// [batch, contracting, free]. Its reshape target and the final result shape
// are then built at runtime and joined with tf.Concat. The inputs are:
// - tf.Shape of the operands;
// - tf.UnsortedSegmentProd with segment-id masks (0 selects the dims to
//   multiply, -1 drops them);
// - tf.Gather.
// CHECK-LABEL:   func @convert_dot_general_dynamic_rhs_out_dim(
// CHECK-SAME:                              %arg0: tensor<4x4x256xf32>,
// CHECK-SAME:                              %arg1: tensor<4x?x256xf32>) -> tensor<4x4x?xf32> {
// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<4x?x256xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK:           %1 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %4 = "tf.Concat"(%cst_4, %cst_3, %3, %2) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %5 = "tf.Reshape"(%0, %4) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32>
// CHECK:           %6 = "tf.BatchMatMulV3"(%arg0, %5) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32>
// CHECK:           %7 = "tf.Shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32>
// CHECK:           %8 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %9 = "tf.Gather"(%7, %cst_5) <{validate_indices = true}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32>
// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %10 = "tf.Gather"(%8, %cst_6) <{validate_indices = true}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %11 = "tf.Concat"(%cst_7, %9, %10) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %12 = "tf.Reshape"(%6, %11) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
// CHECK:           return %12 : tensor<4x4x?xf32>
// CHECK:           }
func.func @convert_dot_general_dynamic_rhs_out_dim(%arg0: tensor<4x4x256xf32>, %arg1: tensor<4x?x256xf32>) -> tensor<4x4x?xf32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [2]
    >} : (tensor<4x4x256xf32>, tensor<4x?x256xf32>) -> tensor<4x4x?xf32>
  func.return %0 : tensor<4x4x?xf32>
}

// Two leading batch dims, one of which is dynamic. Each operand is reshaped to
// [batch..., free, contracting] / [batch..., contracting, free], using a shape
// assembled at runtime from tf.Gather (the batch dims) and
// tf.UnsortedSegmentProd (the free and contracting dims). The product is then
// reshaped to a shape gathered from both operands' tf.Shape results.
// CHECK-LABEL:   func @convert_dot_general_dynamic_batch_dim(
// CHECK-SAME:                              %arg0: tensor<2x?x2x3xf32>,
// CHECK-SAME:                              %arg1: tensor<2x?x4x3xf32>) -> tensor<2x?x2x4xf32> {
// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<2x?x4x3xf32>, tensor<4xi64>) -> tensor<2x?x3x4xf32>
// CHECK:           %1 = "tf.Shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %4 = "tf.Gather"(%1, %cst_3) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %5 = "tf.Concat"(%cst_4, %4, %2, %3) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           %6 = "tf.Reshape"(%arg0, %5) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32>
// CHECK:           %7 = "tf.Shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %8 = "tf.UnsortedSegmentProd"(%7, %cst_5, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %9 = "tf.UnsortedSegmentProd"(%7, %cst_6, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %10 = "tf.Gather"(%7, %cst_8) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
// CHECK-DAG:       %cst_9 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %11 = "tf.Concat"(%cst_9, %10, %9, %8) : (tensor<i32>, tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           %12 = "tf.Reshape"(%0, %11) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32>
// CHECK:           %13 = "tf.BatchMatMulV3"(%6, %12) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32>
// CHECK:           %14 = "tf.Shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32>
// CHECK:           %15 = "tf.Shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %16 = "tf.Gather"(%14, %cst_10) <{validate_indices = true}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
// CHECK-DAG:       %cst_11 = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %17 = "tf.Gather"(%15, %cst_11) <{validate_indices = true}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG:       %cst_12 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %18 = "tf.Concat"(%cst_12, %16, %17) : (tensor<i32>, tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           %19 = "tf.Reshape"(%13, %18) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32>
// CHECK:           return %19 : tensor<2x?x2x4xf32>
// CHECK:           }
func.func @convert_dot_general_dynamic_batch_dim(%arg0: tensor<2x?x2x3xf32>, %arg1: tensor<2x?x4x3xf32>) -> tensor<2x?x2x4xf32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0, 1],
      rhs_batching_dimensions = [0, 1],
      lhs_contracting_dimensions = [3],
      rhs_contracting_dimensions = [3]
    >} : (tensor<2x?x2x3xf32>, tensor<2x?x4x3xf32>) -> tensor<2x?x2x4xf32>
  func.return %0 : tensor<2x?x2x4xf32>
}

// Both operands have a dynamic free dim. For each operand, its two free dims
// [1, 2] are collapsed into one runtime-computed dim with
// tf.UnsortedSegmentProd. The rank-3 product is then reshaped to the rank-5
// result [batch, lhs free..., rhs free...], using dims gathered from both
// operands' shapes.
// CHECK-LABEL:   func @convert_dot_general_dynamic_lhs_rhs_out_dims(
// CHECK-SAME:                              %arg0: tensor<2x2x?x3xf32>,
// CHECK-SAME:                              %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> {
// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %0 = "tf.Transpose"(%arg1, %cst) : (tensor<2x4x?x3xf32>, tensor<4xi64>) -> tensor<2x3x4x?xf32>
// CHECK:           %1 = "tf.Shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %2 = "tf.UnsortedSegmentProd"(%1, %cst_0, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %3 = "tf.UnsortedSegmentProd"(%1, %cst_1, %cst_2) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %4 = "tf.Concat"(%cst_4, %cst_3, %2, %3) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %5 = "tf.Reshape"(%arg0, %4) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32>
// CHECK:           %6 = "tf.Shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %7 = "tf.UnsortedSegmentProd"(%6, %cst_5, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %8 = "tf.UnsortedSegmentProd"(%6, %cst_6, %cst_7) : (tensor<4xi32>, tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:       %cst_9 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %9 = "tf.Concat"(%cst_9, %cst_8, %8, %7) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %10 = "tf.Reshape"(%0, %9) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
// CHECK:           %11 = "tf.BatchMatMulV3"(%5, %10) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32>
// CHECK:           %12 = "tf.Shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32>
// CHECK:           %13 = "tf.Shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32>
// CHECK-DAG:       %cst_10 = "tf.Const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %14 = "tf.Gather"(%12, %cst_10) <{validate_indices = true}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32>
// CHECK-DAG:       %cst_11 = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %15 = "tf.Gather"(%13, %cst_11) <{validate_indices = true}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32>
// CHECK-DAG:       %cst_12 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %16 = "tf.Concat"(%cst_12, %14, %15) : (tensor<i32>, tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
// CHECK:           %17 = "tf.Reshape"(%11, %16) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32>
// CHECK:           return %17 : tensor<2x2x?x4x?xf32>
// CHECK:           }
func.func @convert_dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [3],
      rhs_contracting_dimensions = [3]
    >} : (tensor<2x2x?x3xf32>, tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32>
  func.return %0 : tensor<2x2x?x4x?xf32>
}

// The contracting dim is dynamic in both operands (lhs dim 2, rhs dim 1), and
// no transpose is needed. Each operand is reshaped to [batch, free,
// contracting] / [batch, contracting, free] with sizes computed at runtime.
// The tf.BatchMatMulV3 output is returned directly as the 4x4x256 result.
// CHECK-LABEL:   func @convert_dot_general_dynamic_contracting_dim(
// CHECK-SAME:                              %arg0: tensor<4x4x?xf32>,
// CHECK-SAME:                              %arg1: tensor<4x?x256xf32>) -> tensor<4x4x256xf32> {
// CHECK:           %0 = "tf.Shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32>
// CHECK-DAG:       %cst = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_0 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_1 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %1 = "tf.UnsortedSegmentProd"(%0, %cst, %cst_1) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %2 = "tf.UnsortedSegmentProd"(%0, %cst_0, %cst_1) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_2 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:       %cst_3 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %3 = "tf.Concat"(%cst_3, %cst_2, %1, %2) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %4 = "tf.Reshape"(%arg0, %3) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32>
// CHECK:           %5 = "tf.Shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32>
// CHECK-DAG:       %cst_4 = "tf.Const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_5 = "tf.Const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK-DAG:       %cst_6 = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %6 = "tf.UnsortedSegmentProd"(%5, %cst_4, %cst_6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK:           %7 = "tf.UnsortedSegmentProd"(%5, %cst_5, %cst_6) : (tensor<3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG:       %cst_7 = "tf.Const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:       %cst_8 = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %8 = "tf.Concat"(%cst_8, %cst_7, %7, %6) : (tensor<i32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
// CHECK:           %9 = "tf.Reshape"(%arg1, %8) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32>
// CHECK:           %10 = "tf.BatchMatMulV3"(%4, %9) <{adj_x = false, adj_y = false, grad_x = false, grad_y = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
// CHECK:           return %10 : tensor<4x4x256xf32>
// CHECK:           }
func.func @convert_dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, %arg1: tensor<4x?x256xf32>) -> tensor<4x4x256xf32> {
  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >} : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32>
  func.return %0 : tensor<4x4x256xf32>
}



// Lowers a static-shape 1-D mhlo.convolution ([b, 0, f]x[0, i, o]->[b, 0, f])
// to tf.Conv2D. Each operand gets a trailing unit dimension via tf.Reshape and
// is transposed so the input is NHWC (with W = 1) and the filter is HWIO; the
// Conv2D result is then transposed and reshaped back to the 3-D output shape.
// CHECK-LABEL:   func.func @convert_conv1d(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<16x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
// CHECK:           return %[[VAL_14]] : tensor<16x32x256xbf16>
// CHECK:         }
func.func @convert_conv1d(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16>
  func.return %0 : tensor<16x32x256xbf16>
}

// Lowers a 1-D mhlo.convolution with a dynamic batch dimension to tf.Conv2D.
// The dynamic batch appears in both tf.Reshape shape constants as
// -9223372036854775808 (the minimum i64 value); otherwise the
// reshape/transpose/Conv2D/transpose/reshape sequence matches the static case.
// CHECK-LABEL:   func.func @convert_conv1d_dynamic_batch(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<?x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[-9223372036854775808, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?x32x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<?x32x256x1xbf16>, tensor<4xi64>) -> tensor<?x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<?x32x1x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[-9223372036854775808, 32, 256]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<?x32x256x1xbf16>, tensor<3xi64>) -> tensor<?x32x256xbf16>
// CHECK:           return %[[VAL_14]] : tensor<?x32x256xbf16>
// CHECK:         }
func.func @convert_conv1d_dynamic_batch(%arg0: tensor<?x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<?x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16>
  func.return %0 : tensor<?x32x256xbf16>
}

// Lowers a dynamic-batch 1-D grouped convolution ([b, f, 0]x[o, i, 0]->[b, f, 0],
// feature_group_count = 16, padding [[64, 64]] on the spatial dimension) to
// tf.Conv2D with padding = "EXPLICIT". The dynamic batch appears in the
// tf.Reshape shape constants as -9223372036854775808 (the minimum i64 value).
// Values are captured with FileCheck variables rather than matched by their
// printed SSA names, so the assertions do not depend on printer numbering.
// CHECK-LABEL: func.func private @convert_dynamic_1d_group_conv(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x768x2xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<768x48x128xf32>)
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[-9223372036854775808, 768, 2, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?x768x2xf32>, tensor<4xi64>) -> tensor<?x768x2x1xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<?x768x2x1xf32>, tensor<4xi64>) -> tensor<?x2x1x768xf32>
// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[768, 48, 128, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<768x48x128xf32>, tensor<4xi64>) -> tensor<768x48x128x1xf32>
// CHECK:           %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<768x48x128x1xf32>, tensor<4xi64>) -> tensor<128x1x48x768xf32>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 64, 64, 0, 0, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x2x1x768xf32>, tensor<128x1x48x768xf32>) -> tensor<?x3x1x768xf32>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<?x3x1x768xf32>, tensor<4xi64>) -> tensor<?x768x3x1xf32>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[-9223372036854775808, 768, 3]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<?x768x3x1xf32>, tensor<3xi64>) -> tensor<?x768x3xf32>
// CHECK:           return %[[VAL_14]] : tensor<?x768x3xf32>
func.func private @convert_dynamic_1d_group_conv(%arg1: tensor<?x768x2xf32>, %arg2: tensor<768x48x128xf32>) -> (tensor<?x768x3xf32>) {
  %0 = mhlo.convolution(%arg1, %arg2)
    dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0],
    window = {pad = [[64, 64]]}
      {batch_group_count = 1 : i64, feature_group_count = 16 : i64}
        : (tensor<?x768x2xf32>, tensor<768x48x128xf32>) -> tensor<?x768x3xf32>
  return %0 : tensor<?x768x3xf32>
}

// Lowers a 1-D convolution whose mhlo.convolution omits lhs_dilation,
// rhs_dilation and precision_config. The emitted tf.Conv2D uses unit
// dilations (dilations = [1, 1, 1, 1]) and the usual NHWC reshape/transpose
// sequence around it.
// CHECK-LABEL:   func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<16x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
// CHECK:           return %[[VAL_14]] : tensor<16x32x256xbf16>
// CHECK:         }
func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    padding = dense<0> : tensor<1x2xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16>
  func.return %0 : tensor<16x32x256xbf16>
}

// Dynamic-batch variant of a 1-D convolution that omits lhs_dilation,
// rhs_dilation and precision_config. The dynamic batch appears in both
// tf.Reshape shape constants as -9223372036854775808 (the minimum i64 value),
// and the emitted tf.Conv2D uses unit dilations.
// CHECK-LABEL:   func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf_dynamic_batch(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<?x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[-9223372036854775808, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<?x32x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<?x32x256x1xbf16>, tensor<4xi64>) -> tensor<?x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<?x32x1x256xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<?x32x1x256xbf16>, tensor<4xi64>) -> tensor<?x32x256x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[-9223372036854775808, 32, 256]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<?x32x256x1xbf16>, tensor<3xi64>) -> tensor<?x32x256xbf16>
// CHECK:           return %[[VAL_14]] : tensor<?x32x256xbf16>
// CHECK:         }
func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf_dynamic_batch(%arg0: tensor<?x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    padding = dense<0> : tensor<1x2xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<?x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<?x32x256xbf16>
  func.return %0 : tensor<?x32x256xbf16>
}



// Lowers a 1-D convolution with non-canonical dimension numbers
// ([0, b, f]x[o, 0, i]->[f, b, 0]). The transposes move the input to NHWC and
// the filter to HWIO before tf.Conv2D; the result is then transposed
// ([3, 0, 1, 2]) and reshaped into the requested [f, b, 0] output layout.
// CHECK-LABEL:   func.func @convert_conv1d_non_canonical_dimension_numbers(
// CHECK-SAME:                                                              %[[VAL_0:.*]]: tensor<32x16x256xbf16>,
// CHECK-SAME:                                                              %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[32, 16, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<32x16x256xbf16>, tensor<4xi64>) -> tensor<32x16x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<32x16x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[256, 1, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<256x1x256xbf16>, tensor<4xi64>) -> tensor<256x1x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[1, 3, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<256x1x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[3, 0, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<256x16x32x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[256, 16, 32]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<256x16x32x1xbf16>, tensor<3xi64>) -> tensor<256x16x32xbf16>
// CHECK:           return %[[VAL_14]] : tensor<256x16x32xbf16>
// CHECK:         }
func.func @convert_conv1d_non_canonical_dimension_numbers(%arg0: tensor<32x16x256xbf16>, %arg1: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[0, b, f]x[o, 0, i]->[f, b, 0]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<32x16x256xbf16>, tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16>
  func.return %0 : tensor<256x16x32xbf16>
}

// Negative test: a 1-D convolution whose spatial dimension is dynamic
// (tensor<16x?x256xbf16>) is not lowered; the mhlo.convolution is left in place
// with its attributes intact.
// CHECK-LABEL:   func.func @no_convert_conv1d_dynamic(
// CHECK-SAME:                                         %[[VAL_0:.*]]: tensor<16x?x256xbf16>,
// CHECK-SAME:                                         %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> {
// CHECK:           %[[VAL_2:.*]] = mhlo.convolution(%[[VAL_0]], %[[VAL_1]]) dim_numbers = [b, 0, f]x[0, i, o]->[b, 0, f], window = {stride = [1], pad = {{\[\[}}0, 0]], lhs_dilate = [1], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<16x?x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16>
// CHECK:           return %[[VAL_2]] : tensor<16x?x256xbf16>
// CHECK:         }
func.func @no_convert_conv1d_dynamic(%arg0: tensor<16x?x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<16x?x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x?x256xbf16>
  func.return %0 : tensor<16x?x256xbf16>
}

// Lowers a 1-D convolution with feature_group_count = 2 (256 input features,
// 128 kernel input features) to a single tf.Conv2D. No group attribute is
// emitted on the Conv2D; the filter keeps its 128 input channels.
// Values are captured with FileCheck variables rather than matched by their
// printed SSA names, so the assertions do not depend on printer numbering.
// CHECK-LABEL:   func.func @convert_conv1d_feature_group_gt_1(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<16x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> {
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[1, 128, 128, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x128x128xbf16>, tensor<4xi64>) -> tensor<1x128x128x1xbf16>
// CHECK:           %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x128x128x1xbf16>, tensor<4xi64>) -> tensor<1x1x128x128xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x128x128xbf16>) -> tensor<16x32x1x128xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x128xbf16>, tensor<4xi64>) -> tensor<16x32x128x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 128]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x128x1xbf16>, tensor<3xi64>) -> tensor<16x32x128xbf16>
// CHECK:           return %[[VAL_14]] : tensor<16x32x128xbf16>
// CHECK:         }
func.func @convert_conv1d_feature_group_gt_1(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 2 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<16x32x256xbf16>, tensor<1x128x128xbf16>) -> tensor<16x32x128xbf16>
  func.return %0 : tensor<16x32x128xbf16>
}

// Lowers a 1-D convolution whose mhlo.convolution omits window_strides. The
// emitted tf.Conv2D falls back to unit strides (strides = [1, 1, 1, 1]).
// CHECK-LABEL:   func.func @convert_conv1d_missing_windows_strides_fallback(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<16x32x256xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16>
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16>
// CHECK-DAG:       %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16>
// CHECK:           %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16>
// CHECK:           %[[VAL_11:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16>
// CHECK:           %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16>
// CHECK:           return %[[VAL_14]] : tensor<16x32x256xbf16>
// CHECK:         }
func.func @convert_conv1d_missing_windows_strides_fallback(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<1xi64>
  } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16>
  func.return %0 : tensor<16x32x256xbf16>
}

// Despite the "conv1d" in its name, this test covers a 2-D convolution
// ([b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]) whose mhlo.convolution omits
// window_strides. Because the operands are already NHWC / HWIO, it lowers
// directly to tf.Conv2D with unit strides and no reshapes or transposes.
// CHECK-LABEL:   func.func @convert_conv1d_missing_windows_strides_fallback_2(
// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1x64x64x4xbf16>,
// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> {
// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16>
// CHECK:           return %[[VAL_2]] : tensor<1x62x62x320xbf16>
// CHECK:         }
func.func @convert_conv1d_missing_windows_strides_fallback_2(%arg0: tensor<1x64x64x4xbf16>, %arg1: tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
    feature_group_count = 1 : i64,
    lhs_dilation = dense<[1, 1]> : tensor<2xi64>,
    padding = dense<0> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<[1, 1]> : tensor<2xi64>
  } : (tensor<1x64x64x4xbf16>, tensor<3x3x4x320xbf16>) -> tensor<1x62x62x320xbf16>
  func.return %0 : tensor<1x62x62x320xbf16>
}

// Lowers a canonical 2-D convolution (NHWC input, HWIO 3x3 kernel, unit
// strides and dilations, padding 1 on every side) directly to tf.Conv2D.
// Here the padding of 1 around a 3x3 window is emitted as padding = "SAME".
// The input op spells out its dimension numbers in the #mhlo.conv<raw ...> form.
// CHECK-LABEL:   func @convert_conv2d(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK:           return %[[VAL_2]] : tensor<1x8x8x16xf32>
// CHECK:         }
func.func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
  func.return %0 : tensor<1x8x8x16xf32>
}

// Lowers a grouped 2-D convolution (feature_group_count = 20: 2240 input
// features against 112 kernel input features) to a single tf.Conv2D with no
// group attribute. Padding 1 and stride 2 on both spatial dimensions become
// padding = "EXPLICIT" with explicit_paddings [0, 0, 1, 1, 1, 1, 0, 0] and
// strides [1, 2, 2, 1].
// CHECK-LABEL:   func @convert_group_conv2d(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x14x14x2240xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
// CHECK:           return %[[VAL_2]] : tensor<1x7x7x2240xf32>
// CHECK:         }
func.func @convert_group_conv2d(%arg0: tensor<1x14x14x2240xf32>, %arg1: tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
    feature_group_count = 20 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<2> : tensor<2xi64>} :
    (tensor<1x14x14x2240xf32>, tensor<3x3x112x2240xf32>) -> tensor<1x7x7x2240xf32>
  func.return %0 : tensor<1x7x7x2240xf32>
}

// Lowers an lhs-dilated (lhs_dilation = 2) convolution with [b, f, 0, 1]
// activations and an [0, 1, o, i] kernel to tf.Conv2DBackpropInput. The input
// is transposed to NHWC with [0, 2, 3, 1], the kernel's two spatial axes are
// reversed with tf.ReverseV2, and the NHWC result is transposed back with
// [0, 3, 1, 2]. Operands are matched against previously captured values, so
// the SSA data flow between these ops is verified.
// CHECK-LABEL:    func.func @convert_transpose_conv_with_transpose(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x256x64x64xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<2x2x64x256xf32>) -> tensor<1x64x128x128xf32> {
// CHECK:            %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:            %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1x256x64x64xf32>, tensor<4xi64>) -> tensor<1x64x64x256xf32>
// CHECK:            %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:            %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x2x64x256xf32>, tensor<2xi64>) -> tensor<2x2x64x256xf32>
// CHECK:            %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[1, 128, 128, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6]], %[[VAL_5]], %[[VAL_3]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
// CHECK:            %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:            %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x128x128x64xf32>, tensor<4xi64>) -> tensor<1x64x128x128xf32>
// CHECK:            return %[[VAL_9]] : tensor<1x64x128x128xf32>
// CHECK:           }
func.func @convert_transpose_conv_with_transpose(%arg0: tensor<1x256x64x64xf32>, %arg1: tensor<2x2x64x256xf32>) -> tensor<1x64x128x128xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
  dimension_numbers = #mhlo.conv<[b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1]>,
  feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<1> : tensor<2xi64>} :
  (tensor<1x256x64x64xf32>, tensor<2x2x64x256xf32>) -> tensor<1x64x128x128xf32>
  func.return %0 : tensor<1x64x128x128xf32>
}

// Lowers an lhs-dilated (lhs_dilation = 2) convolution with spatial-major
// [0, 1, b, f] activations and an [0, 1, o, i] kernel to
// tf.Conv2DBackpropInput. The input is transposed to NHWC with [2, 0, 1, 3],
// the kernel's two spatial axes are reversed with tf.ReverseV2, and the NHWC
// result is transposed back to [0, 1, b, f] order with [1, 2, 0, 3]. Operands
// are matched against previously captured values, so the SSA data flow
// between these ops is verified.
// CHECK-LABEL:    func.func @convert_transpose_conv_with_transpose2(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<64x64x1x256xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<2x2x64x256xf32>) -> tensor<128x128x1x64xf32> {
// CHECK:            %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:            %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<64x64x1x256xf32>, tensor<4xi64>) -> tensor<1x64x64x256xf32>
// CHECK:            %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:            %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_4]]) : (tensor<2x2x64x256xf32>, tensor<2xi64>) -> tensor<2x2x64x256xf32>
// CHECK:            %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[1, 128, 128, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:            %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6]], %[[VAL_5]], %[[VAL_3]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<2x2x64x256xf32>, tensor<1x64x64x256xf32>) -> tensor<1x128x128x64xf32>
// CHECK:            %[[VAL_8:.*]] = "tf.Const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:            %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x128x128x64xf32>, tensor<4xi64>) -> tensor<128x128x1x64xf32>
// CHECK:            return %[[VAL_9]] : tensor<128x128x1x64xf32>
// CHECK:           }
func.func @convert_transpose_conv_with_transpose2(%arg0: tensor<64x64x1x256xf32>, %arg1: tensor<2x2x64x256xf32>) -> tensor<128x128x1x64xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
  dimension_numbers = #mhlo.conv<[0, 1, b, f]x[0, 1, o, i]->[0, 1, b, f]>,
  feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense<false> : tensor<2xi1>, window_strides = dense<1> : tensor<2xi64>} :
  (tensor<64x64x1x256xf32>, tensor<2x2x64x256xf32>) -> tensor<128x128x1x64xf32>
  func.return %0 : tensor<128x128x1x64xf32>
}


// An NHWC input with an unknown batch size and an HWIO filter (per the
// dimension_numbers below) should map directly onto a single tf.Conv2D with
// SAME padding; the dynamic batch dimension is carried through to the result.
// CHECK-LABEL:   func @convert_conv2d_dynamic_batch(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x8x8x207xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
// CHECK:           return %[[CONV]] : tensor<?x8x8x16xf32>
// CHECK:         }
func.func @convert_conv2d_dynamic_batch(%arg0: tensor<?x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
  func.return %result : tensor<?x8x8x16xf32>
}

// The `padding` attribute is omitted entirely. The expected lowering treats
// this as no padding: a 3x3 filter over a 6x6 input gives a 4x4 output and a
// tf.Conv2D with VALID padding.
// CHECK-LABEL:   func @convert_conv2d_no_padding(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x6x6x207xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x6x6x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32>
// CHECK:           return %[[CONV]] : tensor<1x4x4x16xf32>
// CHECK:         }
func.func @convert_conv2d_no_padding(%arg0: tensor<1x6x6x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x6x6x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x4x4x16xf32>
  func.return %result : tensor<1x4x4x16xf32>
}

// The `rhs_dilation` attribute is omitted. The expected tf.Conv2D uses unit
// dilations = [1, 1, 1, 1] with SAME padding.
// CHECK-LABEL:   func @convert_conv2d_no_rhs_dilation(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK:           return %[[CONV]] : tensor<1x8x8x16xf32>
// CHECK:         }
func.func @convert_conv2d_no_rhs_dilation(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
  func.return %result : tensor<1x8x8x16xf32>
}

// The `window_strides` attribute is omitted. The expected tf.Conv2D uses unit
// strides = [1, 1, 1, 1] with SAME padding.
// CHECK-LABEL:   func @convert_conv2d_no_window_strides(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK:           return %[[CONV]] : tensor<1x8x8x16xf32>
// CHECK:         }
func.func @convert_conv2d_no_window_strides(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>
  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
  func.return %result : tensor<1x8x8x16xf32>
}

// The `lhs_dilation` attribute is omitted. The expected lowering is still a
// plain tf.Conv2D with SAME padding, not a tf.Conv2DBackpropInput.
// CHECK-LABEL:   func @convert_conv2d_no_lhs_dilation(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK:           return %[[CONV]] : tensor<1x8x8x16xf32>
// CHECK:         }
func.func @convert_conv2d_no_lhs_dilation(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
  func.return %result : tensor<1x8x8x16xf32>
}

// Convolution whose operands and result are not in NHWC/HWIO layout: per the
// dimension_numbers, the input is [0, 1, b, f], the kernel is [0, 1, o, i]
// and the output is [f, 0, 1, b]. The expected lowering transposes the input
// with [2, 0, 1, 3] and the kernel with [0, 1, 3, 2], emits an NHWC
// tf.Conv2D, and transposes the result with [3, 1, 2, 0] back to the
// requested output layout.
// CHECK-LABEL:   func @convert_conv2d_with_transpose(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x8x1x207xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<8x8x1x207xf32>, tensor<4xi64>) -> tensor<1x8x8x207xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x3x16x207xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
// CHECK:           %[[VAL_7:.*]] = "tf.Const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_8:.*]] = "tf.Transpose"(%[[VAL_6]], %[[VAL_7]]) : (tensor<1x8x8x16xf32>, tensor<4xi64>) -> tensor<16x8x8x1xf32>
// CHECK:           return %[[VAL_8]] : tensor<16x8x8x1xf32>
// CHECK:         }
func.func @convert_conv2d_with_transpose(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 2,
      input_feature_dimension = 3,
      input_spatial_dimensions = [0, 1],
      kernel_input_feature_dimension = 3,
      kernel_output_feature_dimension = 2,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 3,
      output_feature_dimension = 0,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32>
  func.return %0 : tensor<16x8x8x1xf32>
}

// Same layouts as convert_conv2d_with_transpose (input [0, 1, b, f], kernel
// [0, 1, o, i], output [f, 0, 1, b]), but the batch dimension is dynamic:
// position 2 of the input and position 3 of the output. The expected lowering
// carries the `?` through both transposes and the NHWC tf.Conv2D.
// CHECK-LABEL:   func @convert_conv2d_with_transpose_dynamic_batch(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x8x?x207xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x16x207xf32>) -> tensor<16x8x8x?xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<8x8x?x207xf32>, tensor<4xi64>) -> tensor<?x8x8x207xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_4]]) : (tensor<3x3x16x207xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.Conv2D"(%[[VAL_3]], %[[VAL_5]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x8x8x16xf32>
// CHECK:           %[[VAL_7:.*]] = "tf.Const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_8:.*]] = "tf.Transpose"(%[[VAL_6]], %[[VAL_7]]) : (tensor<?x8x8x16xf32>, tensor<4xi64>) -> tensor<16x8x8x?xf32>
// CHECK:           return %[[VAL_8]] : tensor<16x8x8x?xf32>
// CHECK:         }
func.func @convert_conv2d_with_transpose_dynamic_batch(%arg0: tensor<8x8x?x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x?xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 2,
      input_feature_dimension = 3,
      input_spatial_dimensions = [0, 1],
      kernel_input_feature_dimension = 3,
      kernel_output_feature_dimension = 2,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 3,
      output_feature_dimension = 0,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<8x8x?x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x?xf32>
  func.return %0 : tensor<16x8x8x?xf32>
}

// An 8x8 filter over an 8x8 input with padding 1 on every side gives a 3x3
// output. At stride 1, SAME padding would give 8x8 and VALID would give 1x1,
// so the expected tf.Conv2D uses padding = "EXPLICIT" with
// explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0].
// CHECK-LABEL:   func @convert_conv2d_explicit_padding(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<64x8x8x8xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<64x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32>
// CHECK:           return %[[CONV]] : tensor<64x3x3x64xf32>
// CHECK:         }
func.func @convert_conv2d_explicit_padding(%arg0: tensor<64x8x8x8xf32>, %arg1: tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<64x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<64x3x3x64xf32>
  func.return %result : tensor<64x3x3x64xf32>
}

// Same explicit-padding convolution as convert_conv2d_explicit_padding, but
// with a dynamic batch dimension. The expected output is again a tf.Conv2D
// with padding = "EXPLICIT" and explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0],
// with the `?` batch preserved in the result type.
// CHECK-LABEL:   func @convert_conv2d_explicit_padding_dynamic_batch(
// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x8x8x8xf32>,
// CHECK-SAME:      %[[FILTER:.*]]: tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32> {
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[INPUT]], %[[FILTER]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [0, 0, 1, 1, 1, 1, 0, 0], padding = "EXPLICIT", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32>
// CHECK:           return %[[CONV]] : tensor<?x3x3x64xf32>
// CHECK:         }
func.func @convert_conv2d_explicit_padding_dynamic_batch(%arg0: tensor<?x8x8x8xf32>, %arg1: tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32> {
  %result = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    feature_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<?x8x8x8xf32>, tensor<8x8x8x64xf32>) -> tensor<?x3x3x64xf32>
  func.return %result : tensor<?x3x3x64xf32>
}

// Padding with negative entries ([[4, -2], [-5, 2]]) and window_strides = 2.
// The expected lowering does not pass the negative amounts to tf.Conv2D.
// Instead it first applies tf.Slice to the input: start [0, 0, 5, 0] and
// size [128, 5, 4, 64], which drops the last 2 rows of dim 1 and the first
// 5 columns of dim 2. Only the non-negative amounts are then passed, as
// explicit_paddings = [0, 0, 4, 0, 0, 2, 0, 0].
// CHECK-LABEL:   func @convert_conv2d_negative_explicit_padding(
// CHECK-SAME:                         %[[ARG0:.*]]: tensor<128x7x9x64xf32>,
// CHECK-SAME:                         %[[ARG1:.*]]: tensor<3x2x64x4xf32>) -> tensor<128x4x3x4xf32> {
// CHECK-DAG:       %[[START:.*]] = "tf.Const"() <{value = dense<[0, 0, 5, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[128, 5, 4, 64]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[SLICED_ARG0:.*]] = "tf.Slice"(%[[ARG0]], %[[START]], %[[SIZE]])
// CHECK-SAME:      (tensor<128x7x9x64xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<128x5x4x64xf32>
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[SLICED_ARG0]], %[[ARG1]])
// CHECK-SAME:      explicit_paddings = [0, 0, 4, 0, 0, 2, 0, 0]
// CHECK-SAME:      (tensor<128x5x4x64xf32>, tensor<3x2x64x4xf32>) -> tensor<128x4x3x4xf32>
// CHECK:           return %[[CONV]] : tensor<128x4x3x4xf32>
// CHECK:         }
func.func @convert_conv2d_negative_explicit_padding(%arg0: tensor<128x7x9x64xf32>, %arg1: tensor<3x2x64x4xf32>) -> tensor<128x4x3x4xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<[[4, -2], [-5, 2]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>
  } : (tensor<128x7x9x64xf32>, tensor<3x2x64x4xf32>) -> tensor<128x4x3x4xf32>
  func.return %0 : tensor<128x4x3x4xf32>
}

// Dynamic-batch variant of convert_conv2d_negative_explicit_padding. It uses
// the same slice start [0, 0, 5, 0] and the same
// explicit_paddings = [0, 0, 4, 0, 0, 2, 0, 0].
// The batch entry of the expected tf.Slice size constant is
// -9223372036854775808 (INT64_MIN), presumably MLIR's dynamic-dimension
// sentinel leaking into the constant.
// NOTE(review): tf.Slice documents size = -1 as "all remaining elements" in
// a dimension. Confirm that INT64_MIN is accepted at runtime, or whether the
// lowering should emit -1 for dynamic dimensions.
// CHECK-LABEL:   func @convert_conv2d_negative_explicit_padding_dynamic_batch(
// CHECK-SAME:                         %[[ARG0:.*]]: tensor<?x7x9x64xf32>,
// CHECK-SAME:                         %[[ARG1:.*]]: tensor<3x2x64x4xf32>) -> tensor<?x4x3x4xf32> {
// CHECK-DAG:       %[[START:.*]] = "tf.Const"() <{value = dense<[0, 0, 5, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[-9223372036854775808, 5, 4, 64]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[SLICED_ARG0:.*]] = "tf.Slice"(%[[ARG0]], %[[START]], %[[SIZE]])
// CHECK-SAME:      (tensor<?x7x9x64xf32>, tensor<4xi64>, tensor<4xi64>) -> tensor<?x5x4x64xf32>
// CHECK:           %[[CONV:.*]] = "tf.Conv2D"(%[[SLICED_ARG0]], %[[ARG1]])
// CHECK-SAME:      explicit_paddings = [0, 0, 4, 0, 0, 2, 0, 0]
// CHECK-SAME:      (tensor<?x5x4x64xf32>, tensor<3x2x64x4xf32>) -> tensor<?x4x3x4xf32>
// CHECK:           return %[[CONV]] : tensor<?x4x3x4xf32>
// CHECK:         }
func.func @convert_conv2d_negative_explicit_padding_dynamic_batch(%arg0: tensor<?x7x9x64xf32>, %arg1: tensor<3x2x64x4xf32>) -> tensor<?x4x3x4xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<[[4, -2], [-5, 2]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<2> : tensor<2xi64>
  } : (tensor<?x7x9x64xf32>, tensor<3x2x64x4xf32>) -> tensor<?x4x3x4xf32>
  func.return %0 : tensor<?x4x3x4xf32>
}

// Depthwise convolution: feature_group_count (207) equals the input channel
// count. The 3x3x1x3312 filter is expected to be reshaped to 3x3x207x16
// (207 * 16 == 3312) before lowering to tf.DepthwiseConv2dNative with SAME
// padding. The reshape's shape operand is expected as an arith.constant,
// unlike the tf.Const operands produced by the neighbouring lowerings.
// CHECK-LABEL:   func @convert_depthwise_conv2d(
// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
// CHECK:           %[[CST:.*]] = arith.constant dense<[3, 3, 207, 16]> : tensor<4xi64>
// CHECK:           %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST]]) : (tensor<3x3x1x3312xf32>, tensor<4xi64>) -> tensor<3x3x207x16xf32>
// CHECK:           %[[VAL_3:.*]] = "tf.DepthwiseConv2dNative"(%[[VAL_0]], %[[VAL_2]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x3312xf32>
// CHECK:           return %[[VAL_3]] : tensor<1x8x8x3312xf32>
// CHECK:         }
func.func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<1x8x8x207xf32>, tensor<3x3x1x3312xf32>) -> tensor<1x8x8x3312xf32>
  func.return %0 : tensor<1x8x8x3312xf32>
}

// Grouped (feature_group_count = 2) transposed convolution (lhs_dilate 4) in
// NCHW layout. The expected lowering handles each of the two feature groups
// separately:
//   1. tf.StridedSlice one channel of the input and the matching filter slice.
//   2. Transpose the input slice to NHWC with [0, 2, 3, 1].
//   3. Reverse the filter slice on dims [0, 1].
//   4. Emit tf.Conv2DBackpropInput with strides [1, 4, 4, 1] and SAME
//      padding.
//   5. Transpose the result back to NCHW with [0, 3, 1, 2].
// The two per-group results are then joined with tf.ConcatV2 along axis 1.
// SSA values are matched through FileCheck variables rather than hard-coded
// printer names (%cst_7, %4, ...), so the test does not break when the
// printer's value numbering changes.
// CHECK-LABEL:   func @convert_depthwise_transposed_conv2d
// CHECK:  %[[IN0_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:  %[[IN0_END:.*]] = "tf.Const"() <{value = dense<[1, 1, 20, 20]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN0_STRIDES:.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN0:.*]] = "tf.StridedSlice"(%arg0, %[[IN0_BEGIN]], %[[IN0_END]], %[[IN0_STRIDES]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x2x20x20xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<1x1x20x20xf32>
// CHECK:  %[[F0_BEGIN:.*]] = "tf.Const"() <{value = dense<0> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:  %[[F0_END:.*]] = "tf.Const"() <{value = dense<[8, 8, 1, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[F0_STRIDES:.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[F0:.*]] = "tf.StridedSlice"(%arg1, %[[F0_BEGIN]], %[[F0_END]], %[[F0_STRIDES]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<8x8x2x1xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<8x8x1x1xf32>
// CHECK:  %[[IN0_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN0_NHWC:.*]] = "tf.Transpose"(%[[IN0]], %[[IN0_PERM]]) : (tensor<1x1x20x20xf32>, tensor<4xi64>) -> tensor<1x20x20x1xf32>
// CHECK:  %[[F0_AXES:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:  %[[F0_REV:.*]] = "tf.ReverseV2"(%[[F0]], %[[F0_AXES]]) : (tensor<8x8x1x1xf32>, tensor<2xi64>) -> tensor<8x8x1x1xf32>
// CHECK:  %[[OUT0_SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 80, 80, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:  %[[CONV0:.*]] = "tf.Conv2DBackpropInput"(%[[OUT0_SHAPE]], %[[F0_REV]], %[[IN0_NHWC]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 4, 4, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<8x8x1x1xf32>, tensor<1x20x20x1xf32>) -> tensor<1x80x80x1xf32>
// CHECK:  %[[OUT0_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[OUT0:.*]] = "tf.Transpose"(%[[CONV0]], %[[OUT0_PERM]]) : (tensor<1x80x80x1xf32>, tensor<4xi64>) -> tensor<1x1x80x80xf32>
// CHECK:  %[[IN1_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 1, 0, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:  %[[IN1_END:.*]] = "tf.Const"() <{value = dense<[1, 2, 20, 20]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN1_STRIDES:.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN1:.*]] = "tf.StridedSlice"(%arg0, %[[IN1_BEGIN]], %[[IN1_END]], %[[IN1_STRIDES]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x2x20x20xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<1x1x20x20xf32>
// CHECK:  %[[F1_BEGIN:.*]] = "tf.Const"() <{value = dense<[0, 0, 1, 0]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:  %[[F1_END:.*]] = "tf.Const"() <{value = dense<[8, 8, 2, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[F1_STRIDES:.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[F1:.*]] = "tf.StridedSlice"(%arg1, %[[F1_BEGIN]], %[[F1_END]], %[[F1_STRIDES]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<8x8x2x1xf32>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<8x8x1x1xf32>
// CHECK:  %[[IN1_PERM:.*]] = "tf.Const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[IN1_NHWC:.*]] = "tf.Transpose"(%[[IN1]], %[[IN1_PERM]]) : (tensor<1x1x20x20xf32>, tensor<4xi64>) -> tensor<1x20x20x1xf32>
// CHECK:  %[[F1_AXES:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:  %[[F1_REV:.*]] = "tf.ReverseV2"(%[[F1]], %[[F1_AXES]]) : (tensor<8x8x1x1xf32>, tensor<2xi64>) -> tensor<8x8x1x1xf32>
// CHECK:  %[[OUT1_SHAPE:.*]] = "tf.Const"() <{value = dense<[1, 80, 80, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:  %[[CONV1:.*]] = "tf.Conv2DBackpropInput"(%[[OUT1_SHAPE]], %[[F1_REV]], %[[IN1_NHWC]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 4, 4, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<8x8x1x1xf32>, tensor<1x20x20x1xf32>) -> tensor<1x80x80x1xf32>
// CHECK:  %[[OUT1_PERM:.*]] = "tf.Const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:  %[[OUT1:.*]] = "tf.Transpose"(%[[CONV1]], %[[OUT1_PERM]]) : (tensor<1x80x80x1xf32>, tensor<4xi64>) -> tensor<1x1x80x80xf32>
// CHECK:  %[[CONCAT_AXIS:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
// CHECK:  %[[CONCAT:.*]] = "tf.ConcatV2"(%[[OUT0]], %[[OUT1]], %[[CONCAT_AXIS]]) : (tensor<1x1x80x80xf32>, tensor<1x1x80x80xf32>, tensor<i64>) -> tensor<1x2x80x80xf32>
// CHECK:  return %[[CONCAT]] : tensor<1x2x80x80xf32>
func.func @convert_depthwise_transposed_conv2d(%arg0: tensor<1x2x20x20xf32>, %arg1: tensor<8x8x2x1xf32>) -> (tensor<1x2x80x80xf32>) {
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {pad = [[5, 5], [5, 5]], lhs_dilate = [4, 4]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x20x20xf32>, tensor<8x8x2x1xf32>) -> tensor<1x2x80x80xf32>
  func.return %0 : tensor<1x2x80x80xf32>
}

// Grouped convolution (feature_group_count = 16, equal to the 16 input
// channels) with a 1x257 filter. It is expected to lower to a single
// tf.ResizeBilinear (align_corners = true, half_pixel_centers = false) that
// resizes the width from 624 to 904.
// The width parameters satisfy (624 - 1) * 129 == (904 - 1) * 89 == 80367,
// so lhs_dilation / window_strides encode the align-corners scale 903 / 623.
// Presumably this is the relation the pass matches on; confirm against the
// rewrite pattern.
// CHECK-LABEL:   func @convert_conv2d_to_resize(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x56x624x16xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x257x16x1xf32>) -> tensor<1x56x904x16xf32> {
// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[56, 904]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x624x16xf32>, tensor<2xi32>) -> tensor<1x56x904x16xf32>
// CHECK:           return %[[VAL_2]] : tensor<1x56x904x16xf32>
// CHECK:         }
func.func @convert_conv2d_to_resize(%arg0: tensor<1x56x624x16xf32>, %arg1: tensor<1x257x16x1xf32>) -> tensor<1x56x904x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>,
    feature_group_count = 16 : i64,
    lhs_dilation = dense<[1, 129]> : tensor<2xi64>,
    padding = dense<[[0, 0], [128, 128]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<[1, 89]> : tensor<2xi64>} : (tensor<1x56x624x16xf32>, tensor<1x257x16x1xf32>) -> tensor<1x56x904x16xf32>
  func.return %0 : tensor<1x56x904x16xf32>
}

// Grouped convolution (feature_group_count = 16) with lhs_dilation = [2, 1],
// a 3x1 filter and padding [[1, 1], [0, 0]]. This shape is also a
// transposed-convolution candidate, but the expected output is a single
// tf.ResizeBilinear (align_corners = true) from height 56 to 111.
// (56 - 1) * 2 == (111 - 1) * 1, with width 1248 unchanged.
// The function name keeps its original spelling ("perferred" =
// "preferred") so that it stays stable.
// CHECK-LABEL:   func @convert_conv2d_resize_perferred(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x56x1248x16xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x1x16x1xf32>) -> tensor<1x111x1248x16xf32> {
// CHECK-DAG:       %[[SIZE:.*]] = "tf.Const"() <{value = dense<[111, 1248]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:           %[[VAL_2:.*]] = "tf.ResizeBilinear"(%[[VAL_0]], %[[SIZE]]) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x1248x16xf32>, tensor<2xi32>) -> tensor<1x111x1248x16xf32>
// CHECK:           return %[[VAL_2]] : tensor<1x111x1248x16xf32>
// CHECK:         }
func.func @convert_conv2d_resize_perferred(%arg0: tensor<1x56x1248x16xf32>, %arg1: tensor<3x1x16x1xf32>) -> tensor<1x111x1248x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>,
    feature_group_count = 16 : i64,
    lhs_dilation = dense<[2, 1]> : tensor<2xi64>,
    padding = dense<[[1, 1], [0, 0]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>],
    rhs_dilation = dense<[1, 1]> : tensor<2xi64>,
    window_strides = dense<[1, 1]> : tensor<2xi64>} : (tensor<1x56x1248x16xf32>, tensor<3x1x16x1xf32>) -> tensor<1x111x1248x16xf32>
  func.return %0 : tensor<1x111x1248x16xf32>
}

// Transposed convolution written as an mhlo.convolution with
// lhs_dilation = 2 and padding 2 on every side of a 4x4 filter (NHWC / HWOI).
// It is expected to lower to tf.ReverseV2 of the filter on dims [0, 1],
// followed by tf.Conv2DBackpropInput with strides [1, 2, 2, 1] and SAME
// padding, taking 256x256 to 512x512.
// CHECK-LABEL:   func @convert_conv2d_back_prop_input_same_pad(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x256x256x2xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<4x4x2x2xf32>) -> tensor<1x512x512x2xf32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_4:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_3]]) : (tensor<4x4x2x2xf32>, tensor<2xi64>) -> tensor<4x4x2x2xf32>
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 512, 512, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_2]], %[[VAL_4]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<4x4x2x2xf32>, tensor<1x256x256x2xf32>) -> tensor<1x512x512x2xf32>
// CHECK:           return %[[VAL_5]] : tensor<1x512x512x2xf32>
// CHECK:         }
func.func @convert_conv2d_back_prop_input_same_pad(%arg0: tensor<1x256x256x2xf32>, %arg1: tensor<4x4x2x2xf32>) -> tensor<1x512x512x2xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<[[2, 2], [2, 2]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x256x256x2xf32>, tensor<4x4x2x2xf32>) -> tensor<1x512x512x2xf32>
  func.return %0 : tensor<1x512x512x2xf32>
}

// Negative padding ([[-2, -2], [-2, -2]]) combined with lhs_dilation = 2:
// the legalization must not produce a tf.Conv2DBackpropInput for this
// convolution.
// CHECK-LABEL:   func @convert_conv2d_back_prop_input_negative_pad(
// CHECK-NOT:       "tf.Conv2DBackpropInput"
func.func @convert_conv2d_back_prop_input_negative_pad(%arg0: tensor<1x256x256x2xf32>, %arg1: tensor<4x4x2x2xf32>) -> tensor<1x504x504x2xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<[[-2, -2], [-2, -2]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x256x256x2xf32>, tensor<4x4x2x2xf32>) -> tensor<1x504x504x2xf32>
  func.return %0 : tensor<1x504x504x2xf32>
}


// Transposed convolution with asymmetric padding [[2, 1], [2, 1]],
// lhs_dilation = 2 and a 3x3 kernel whose output-feature dim is 2 and
// input-feature dim is 3 (raw dimension numbers). Legalized to
// tf.Conv2DBackpropInput with padding = "VALID" and strides [1, 2, 2, 1],
// after reversing the kernel over dims [0, 1] with tf.ReverseV2; the
// 8x8x8x64 output shape is passed as a tf.Const operand.
// CHECK-LABEL:   func @convert_conv2d_back_prop_input(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x4x4x32xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x64x32xf32>) -> tensor<8x8x8x64xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.ReverseV2"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x3x64x32xf32>, tensor<2xi64>) -> tensor<3x3x64x32xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[8, 8, 8, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_4]], %[[VAL_3]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
// CHECK:           return %[[VAL_5]] : tensor<8x8x8x64xf32>
// CHECK:         }
func.func @convert_conv2d_back_prop_input(%arg0: tensor<8x4x4x32xf32>, %arg1: tensor<3x3x64x32xf32>) -> tensor<8x8x8x64xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 3,
      kernel_output_feature_dimension = 2,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<[[2, 1], [2, 1]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
  } : (tensor<8x4x4x32xf32>, tensor<3x3x64x32xf32>) -> tensor<8x8x8x64xf32>
  func.return %0 : tensor<8x8x8x64xf32>
}

// Transposed convolution like convert_conv2d_back_prop_input, but the kernel
// uses input-feature dim 2 and output-feature dim 3. The legalization first
// transposes the filter with permutation [0, 1, 3, 2]
// (3x3x32x64 -> 3x3x64x32), then reverses it over dims [0, 1], and emits
// tf.Conv2DBackpropInput with padding = "VALID". The two constants and the
// transpose are matched order-independently.
// CHECK-LABEL:   func @convert_conv2d_back_prop_input_transpose_filter(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<8x4x4x32xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<3x3x32x64xf32>) -> tensor<8x8x8x64xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Transpose"(%[[VAL_1]], %[[VAL_3]]) : (tensor<3x3x32x64xf32>, tensor<4xi64>) -> tensor<3x3x64x32xf32>
// CHECK:           %[[VAL_5:.*]] = "tf.ReverseV2"(%[[VAL_4]], %[[VAL_2]]) : (tensor<3x3x64x32xf32>, tensor<2xi64>) -> tensor<3x3x64x32xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[8, 8, 8, 64]> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:           %[[VAL_7:.*]] = "tf.Conv2DBackpropInput"(%[[VAL_6]], %[[VAL_5]], %[[VAL_0]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> : (tensor<4xi32>, tensor<3x3x64x32xf32>, tensor<8x4x4x32xf32>) -> tensor<8x8x8x64xf32>
// CHECK:           return %[[VAL_7]] : tensor<8x8x8x64xf32>
// CHECK:         }
func.func @convert_conv2d_back_prop_input_transpose_filter(%arg0: tensor<8x4x4x32xf32>, %arg1: tensor<3x3x32x64xf32>) -> tensor<8x8x8x64xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<2> : tensor<2xi64>, padding = dense<[[2, 1], [2, 1]]> : tensor<2x2xi64>,
    precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>
  } : (tensor<8x4x4x32xf32>, tensor<3x3x32x64xf32>) -> tensor<8x8x8x64xf32>
  func.return %0 : tensor<8x8x8x64xf32>
}

// Forward convolution with unit lhs/rhs dilation, zero padding and unit
// strides, in [b, 0, 1, f] input and [0, 1, i, o] kernel layout, is legalized
// to tf.Conv2D with padding = "VALID" (8x8 input, 3x3 kernel -> 6x6 output).
// CHECK-LABEL:   func @convert_conv2d_valid_padding(
// CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<1x8x8x207xf32>,
// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
// CHECK:           return %[[VAL_2]] : tensor<1x6x6x16xf32>
// CHECK:         }
func.func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32>
  func.return %0 : tensor<1x6x6x16xf32>
}

// The same valid-padding forward convolution as convert_conv2d_valid_padding,
// but with a dynamic batch dimension. tf.Conv2D is still produced, and the
// `?` batch dim is carried through to its result type.
// CHECK-LABEL:   func @convert_conv2d_valid_padding_dynamic_batch(
// CHECK-SAME:                                       %[[VAL_0:.*]]: tensor<?x8x8x207xf32>,
// CHECK-SAME:                                       %[[VAL_1:.*]]: tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Conv2D"(%[[VAL_0]], %[[VAL_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> : (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32>
// CHECK:           return %[[VAL_2]] : tensor<?x6x6x16xf32>
// CHECK:         }
func.func @convert_conv2d_valid_padding_dynamic_batch(%arg0: tensor<?x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64,
    dimension_numbers = #mhlo.conv<raw
      input_batch_dimension = 0,
      input_feature_dimension = 3,
      input_spatial_dimensions = [1, 2],
      kernel_input_feature_dimension = 2,
      kernel_output_feature_dimension = 3,
      kernel_spatial_dimensions = [0, 1],
      output_batch_dimension = 0,
      output_feature_dimension = 3,
      output_spatial_dimensions = [1, 2]
    >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
       (tensor<?x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<?x6x6x16xf32>
  func.return %0 : tensor<?x6x6x16xf32>
}

// Multiplicative reduction over dim 1 seeded with the constant 1.0 is
// legalized to tf.Prod with keep_dims = false.
// CHECK-LABEL:   func @convert_reduce_to_prod(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           return %[[VAL_3]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_prod(%input: tensor<1x256xf32>) -> tensor<1xf32> {
  %one = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %product = "mhlo.reduce"(%input, %one) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.multiply %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %product : tensor<1xf32>
}

// Additive reduction over dim 1 seeded with the constant 0.0 is legalized to
// tf.Sum with keep_dims = false.
// CHECK-LABEL:   func @convert_reduce_to_sum(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           return %[[VAL_3]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_sum(%input: tensor<1x256xf32>) -> tensor<1xf32> {
  %zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %total = "mhlo.reduce"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %total : tensor<1xf32>
}

// When the init value is a function argument rather than a constant, the
// multiplicative reduction becomes tf.Prod followed by a tf.Mul that
// multiplies the init value into the reduced result.
// CHECK-LABEL:   func @convert_reduce_to_prod_non_constant_init(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x256xf32>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<f32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_1:.*]] = "tf.Prod"(%[[ARG_0]], %[[VAL_0]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           %[[VAL_2:.*]] = "tf.Mul"(%[[VAL_1]], %[[ARG_1]]) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
// CHECK:           return %[[VAL_2]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_prod_non_constant_init(%input: tensor<1x256xf32>, %init: tensor<f32>) -> tensor<1xf32> {
  %product = "mhlo.reduce"(%input, %init) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.multiply %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %product : tensor<1xf32>
}


// When the init value is a function argument rather than a constant, the
// additive reduction becomes tf.Sum followed by a tf.Add that adds the init
// value to the reduced result.
// CHECK-LABEL:   func @convert_reduce_to_sum_non_constant_init(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x256xf32>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<f32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_1:.*]] = "tf.Sum"(%[[ARG_0]], %[[VAL_0]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           %[[VAL_2:.*]] = "tf.Add"(%[[VAL_1]], %[[ARG_1]]) : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
// CHECK:           return %[[VAL_2]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_sum_non_constant_init(%input: tensor<1x256xf32>, %init: tensor<f32>) -> tensor<1xf32> {
  %total = "mhlo.reduce"(%input, %init) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %total : tensor<1xf32>
}

// Integer (i32) multiplicative reduction over dim 1 seeded with 1 is
// legalized to tf.Prod with keep_dims = false.
// CHECK-LABEL:   func @convert_int_reduce_to_prod(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Prod"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK:           return %[[VAL_3]] : tensor<1xi32>
// CHECK:         }
func.func @convert_int_reduce_to_prod(%input: tensor<1x256xi32>) -> tensor<1xi32> {
  %one = mhlo.constant dense<1> : tensor<i32>
  %product = "mhlo.reduce"(%input, %one) ({
  ^bb0(%acc: tensor<i32>, %elem: tensor<i32>):
    %step = mhlo.multiply %acc, %elem : tensor<i32>
    mhlo.return %step : tensor<i32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xi32>, tensor<i32>) -> tensor<1xi32>
  func.return %product : tensor<1xi32>
}


// Integer (i32) additive reduction over dim 1 seeded with 0 is legalized to
// tf.Sum with keep_dims = false.
// CHECK-LABEL:   func @convert_int_reduce_to_sum(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Sum"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK:           return %[[VAL_3]] : tensor<1xi32>
// CHECK:         }
func.func @convert_int_reduce_to_sum(%input: tensor<1x256xi32>) -> tensor<1xi32> {
  %zero = mhlo.constant dense<0> : tensor<i32>
  %total = "mhlo.reduce"(%input, %zero) ({
  ^bb0(%acc: tensor<i32>, %elem: tensor<i32>):
    %step = mhlo.add %acc, %elem : tensor<i32>
    mhlo.return %step : tensor<i32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xi32>, tensor<i32>) -> tensor<1xi32>
  func.return %total : tensor<1xi32>
}

// Max-reduction over dim 1 seeded with -inf is legalized to tf.Max with
// keep_dims = false.
// CHECK-LABEL:   func @convert_reduce_to_max(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           return %[[VAL_3]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_max(%input: tensor<1x256xf32>) -> tensor<1xf32> {
  // 0xFF800000 is the f32 bit pattern of negative infinity.
  %neg_inf = mhlo.constant dense<0xFF800000> : tensor<f32>
  %maxima = "mhlo.reduce"(%input, %neg_inf) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.maximum %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %maxima : tensor<1xf32>
}

// Max-reduction over dim 1 of an i32 tensor, seeded with the smallest i32
// value, is legalized to tf.Max with keep_dims = false. The input uses the
// pretty mhlo.reduce syntax (init / across dimensions / reducer).
// CHECK-LABEL:   func @convert_reduce_to_max_int(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x4xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK:           return %[[VAL_3]] : tensor<1xi32>
// CHECK:         }
func.func @convert_reduce_to_max_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
  // -2147483648 is MIN for INT32
  %0 = mhlo.constant dense<-2147483648> : tensor<i32>
  %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [1] : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
   reducer(%arg2: tensor<i32>, %arg3: tensor<i32>)  {
    %max = mhlo.maximum %arg2, %arg3 : tensor<i32>
    "mhlo.return"(%max) : (tensor<i32>) -> ()
  }
  func.return %1 : tensor<1xi32>
}

// Min-reduction over dim 1 seeded with +inf is legalized to tf.Min with
// keep_dims = false.
// CHECK-LABEL:   func @convert_reduce_to_min(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:           return %[[VAL_3]] : tensor<1xf32>
// CHECK:         }
func.func @convert_reduce_to_min(%input: tensor<1x256xf32>) -> tensor<1xf32> {
  // 0x7F800000 is the f32 bit pattern of positive infinity.
  %pos_inf = mhlo.constant dense<0x7F800000> : tensor<f32>
  %minima = "mhlo.reduce"(%input, %pos_inf) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %step = mhlo.minimum %acc, %elem : tensor<f32>
    mhlo.return %step : tensor<f32>
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
  func.return %minima : tensor<1xf32>
}

// Min-reduction over dim 1 of an i32 tensor, seeded with the largest i32
// value, is legalized to tf.Min with keep_dims = false. The input uses the
// pretty mhlo.reduce syntax (init / across dimensions / reducer).
// CHECK-LABEL:   func @convert_reduce_to_min_int(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<1x4xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_2]]) <{keep_dims = false}> : (tensor<1x4xi32>, tensor<1xi64>) -> tensor<1xi32>
// CHECK:           return %[[VAL_3]] : tensor<1xi32>
// CHECK:         }
func.func @convert_reduce_to_min_int(%arg0: tensor<1x4xi32>) -> tensor<1xi32> {
  // 2147483647 is MAX for INT32
  %0 = mhlo.constant dense<2147483647> : tensor<i32>
  %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [1] : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
   reducer(%arg2: tensor<i32>, %arg3: tensor<i32>)  {
    %min = mhlo.minimum %arg2, %arg3 : tensor<i32>
    "mhlo.return"(%min) : (tensor<i32>) -> ()
  }
  func.return %1 : tensor<1xi32>
}

// A 1-D f32 iota of length 123 is legalized to tf.Range with scalar start
// 0.0, limit 123.0 and delta 1.0 constants.
// CHECK-LABEL:   func @convert_iota_1d() -> tensor<123xf32> {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.230000e+02> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<123xf32>
// CHECK:           return %[[VAL_3]] : tensor<123xf32>
// CHECK:         }
func.func @convert_iota_1d() -> tensor<123xf32> {
  %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xf32>
  func.return %0 : tensor<123xf32>
}

// An iota along dim 1 of a 5x7x9 i32 tensor is legalized to a 1-D tf.Range
// over [0, 7) with step 1. The range is reshaped to 1x7x1 and then
// broadcast to the full 5x7x9 shape with tf.BroadcastTo.
// CHECK-LABEL:   func @convert_iota_3d() -> tensor<5x7x9xi32> {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<7> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<7xi32>
// CHECK:           %[[VAL_4:.*]] = "tf.Const"() <{value = dense<[1, 7, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.Reshape"(%[[VAL_3]], %[[VAL_4]]) : (tensor<7xi32>, tensor<3xi64>) -> tensor<1x7x1xi32>
// CHECK:           %[[VAL_6:.*]] = "tf.Const"() <{value = dense<[5, 7, 9]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_7:.*]] = "tf.BroadcastTo"(%[[VAL_5]], %[[VAL_6]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32>
// CHECK:           return %[[VAL_7]] : tensor<5x7x9xi32>
// CHECK:         }
func.func @convert_iota_3d() -> tensor<5x7x9xi32> {
  %0 = "mhlo.iota"() <{ iota_dimension = 1 : i64 }> : () -> tensor<5x7x9xi32>
  func.return %0 : tensor<5x7x9xi32>
}

// A 1-D ui64 iota is legalized to tf.Range with ui64 start, limit and delta
// constants (0, 123 and 1).
// CHECK-LABEL:   func @convert_iota_ui64() -> tensor<123xui64> {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<ui64>}> : () -> tensor<ui64>
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<123> : tensor<ui64>}> : () -> tensor<ui64>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<ui64>}> : () -> tensor<ui64>
// CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<ui64>, tensor<ui64>, tensor<ui64>) -> tensor<123xui64>
// CHECK:           return %[[VAL_3]] : tensor<123xui64>
// CHECK:         }
func.func @convert_iota_ui64() -> tensor<123xui64> {
  %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xui64>
  func.return %0 : tensor<123xui64>
}

// A ui8 iota is not legalized: the mhlo.iota op must still be present after
// the pass runs.
// CHECK-LABEL: func @no_convert_iota_ui8
func.func @no_convert_iota_ui8() -> tensor<123xui8> {
  // CHECK: "mhlo.iota"
  %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xui8>
  func.return %0 : tensor<123xui8>
}

// A sum reduce_window over a [1, 3, 3, 1] window with stride [1, 2, 2, 1]
// and zero padding, divided by the constant 9.0 (= 3 * 3 window elements),
// is legalized to tf.AvgPool with padding = "VALID" in NHWC.
// CHECK-LABEL:   func @convert_avgpool_valid(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK:         }
func.func @convert_avgpool_valid(%input: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %divisor = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<0> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %average = mhlo.divide %window_sum, %divisor : tensor<4x7x7x8xf32>
  func.return %average : tensor<4x7x7x8xf32>
}

// Same as a valid 3x3 / stride-2 average pool, but the divisor 9.0 is a
// scalar broadcast to the output shape with mhlo.broadcast_in_dim instead
// of a full-shape constant. Still legalized to tf.AvgPool, padding "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_broadcasted_divisor(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK:         }
func.func @convert_avgpool_valid_broadcasted_divisor(%input: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %nine = mhlo.constant dense<9.0> : tensor<f32>
  %divisor = "mhlo.broadcast_in_dim"(%nine) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x7x7x8xf32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<0> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %average = mhlo.divide %window_sum, %divisor : tensor<4x7x7x8xf32>
  func.return %average : tensor<4x7x7x8xf32>
}

// Channel-first variant: the window [1, 1, 3, 3] and stride [1, 1, 2, 2]
// act on the last two dims of a 4x3x16x16 input, so the result is
// tf.AvgPool with data_format = "NCHW" and padding = "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_channel_first(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]}> : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x3x7x7xf32>
// CHECK:         }
func.func @convert_avgpool_valid_channel_first(%input: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
  %divisor = mhlo.constant dense<9.0> : tensor<4x3x7x7xf32>
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<0> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
    window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7xf32>
  %average = mhlo.divide %window_sum, %divisor : tensor<4x3x7x7xf32>
  func.return %average : tensor<4x3x7x7xf32>
}

// The divisor here is not a constant: it is a second sum reduce_window,
// with the same window, stride and zero padding, over an all-ones
// constant tensor. The pair is still legalized to tf.AvgPool, padding "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_rw(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK:         }
func.func @convert_avgpool_valid_rw(%input: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
  %ones = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %window_count = "mhlo.reduce_window"(%ones, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %average = mhlo.divide %window_sum, %window_count : tensor<4x7x7x8xf32>
  func.return %average : tensor<4x7x7x8xf32>
}

// Like convert_avgpool_valid_rw, but the all-ones tensor fed to the count
// reduce_window is a scalar 1.0 broadcast with mhlo.broadcast_in_dim rather
// than a full-shape constant. Still legalized to tf.AvgPool, padding "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_rw_broadcasted_const_lhs(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK:         }
func.func @convert_avgpool_valid_rw_broadcasted_const_lhs(%input: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
  %one = mhlo.constant dense<1.0> : tensor<f32>
  %ones = "mhlo.broadcast_in_dim"(%one) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor<f32>) -> tensor<4x16x16x8xf32>
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %window_count = "mhlo.reduce_window"(%ones, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  %average = mhlo.divide %window_sum, %window_count : tensor<4x7x7x8xf32>
  func.return %average : tensor<4x7x7x8xf32>
}

// 3-D average pool: a sum reduce_window over a [1, 3, 3, 3, 1] window with
// stride [1, 2, 2, 2, 1], divided by 27.0 (= 3 * 3 * 3), is legalized to
// tf.AvgPool3D with data_format = "NDHWC" and padding = "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_3d(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) <{data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]}> : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
// CHECK:         }
func.func @convert_avgpool_valid_3d(%input: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %divisor = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<5xi64>,
    padding = dense<0> : tensor<5x2xi64>,
    window_dilations = dense<1> : tensor<5xi64>,
    window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
    window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
  %average = mhlo.divide %window_sum, %divisor : tensor<4x7x7x7x8xf32>
  func.return %average : tensor<4x7x7x7x8xf32>
}

// Channel-first 3-D average pool: the window [1, 1, 3, 3, 3] and stride
// [1, 1, 2, 2, 2] act on the last three dims of a 4x3x16x16x16 input, so the
// result is tf.AvgPool3D with data_format = "NCDHW" and padding = "VALID".
// CHECK-LABEL:   func @convert_avgpool_valid_3d_channel_first(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) <{data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]}> : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x3x7x7x7xf32>
// CHECK:         }
func.func @convert_avgpool_valid_3d_channel_first(%input: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
  %divisor = mhlo.constant dense<27.0> : tensor<4x3x7x7x7xf32>
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<5xi64>,
    padding = dense<0> : tensor<5x2xi64>,
    window_dilations = dense<1> : tensor<5xi64>,
    window_dimensions = dense<[1, 1, 3, 3, 3]> : tensor<5xi64>,
    window_strides = dense<[1, 1, 2, 2, 2]> : tensor<5xi64>} : (tensor<4x3x16x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7x7xf32>
  %average = mhlo.divide %window_sum, %divisor : tensor<4x3x7x7x7xf32>
  func.return %average : tensor<4x3x7x7x7xf32>
}

// SAME-padded average pool: both the data sum and the all-ones count
// reduce_window use padding [[0, 0], [0, 1], [0, 1], [0, 0]], so each 16-wide
// spatial dim maps to 8 outputs. The quotient is legalized to tf.AvgPool
// with padding = "SAME".
// CHECK-LABEL:   func @convert_avgpool_same(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK:         }
func.func @convert_avgpool_same(%input: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
  %ones = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
  %zero = mhlo.constant dense<0.0> : tensor<f32>
  %window_sum = "mhlo.reduce_window"(%input, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
  %window_count = "mhlo.reduce_window"(%ones, %zero) ({
  ^bb0(%acc: tensor<f32>, %elem: tensor<f32>):
    %partial = mhlo.add %acc, %elem : tensor<f32>
    mhlo.return %partial : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
  %average = mhlo.divide %window_sum, %window_count : tensor<4x8x8x8xf32>
  func.return %average : tensor<4x8x8x8xf32>
}

// CHECK-LABEL:   func @convert_avgpool_reshape_broadcast(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK:         }
// SAME-padded average pool whose divisor (the windowed sum of a 1x16x16x1
// ones tensor) is computed once, reshaped to 8x8, and broadcast into
// dimensions 1 and 2 of the 4x8x8x8 result before the division.
func.func @convert_avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<1x16x16x1xf32>
  %1 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %2 = "mhlo.reduce_window"(%arg0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %7 = mhlo.add %arg1, %arg2 : tensor<f32>
    mhlo.return %7 : tensor<f32>
  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
  // Per-window element counts, computed on the unbatched, single-channel shape.
  %3 = "mhlo.reduce_window"(%0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %7 = mhlo.add %arg1, %arg2 : tensor<f32>
    mhlo.return %7 : tensor<f32>
  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor<f32>) -> tensor<1x8x8x1xf32>
  %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32>
  %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32>
  %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32>
  func.return %6 : tensor<4x8x8x8xf32>
}

// CHECK-LABEL:   func @convert_maxpool_valid(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK:         }
// Unpadded 3x3 max pool with stride 2 over the spatial dimensions.
func.func @convert_maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
  // 0xFF800000 is the f32 bit pattern of -infinity, the identity for max.
  %neg_inf = mhlo.constant dense<0xFF800000> : tensor<f32>
  %pooled = "mhlo.reduce_window"(%arg0, %neg_inf) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    mhlo.return %max : tensor<f32>
  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
  func.return %pooled : tensor<4x7x7x8xf32>
}

// CHECK-LABEL:   func @convert_maxpool_valid_channel_first(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NCHW", explicit_paddings = [], ksize = [1, 1, 3, 3], padding = "VALID", strides = [1, 1, 2, 2]}> : (tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x3x7x7xf32>
// CHECK:         }
// Channel-first (NCHW) unpadded 3x3 max pool with stride 2; the window and
// stride are 1 on the batch and channel dimensions.
func.func @convert_maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> {
  // "0xFF800000" represents -INF for f32.
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
    mhlo.return %2 : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<4xi64>,
    padding = dense<0> : tensor<4x2xi64>,
    window_dilations = dense<1> : tensor<4xi64>,
    window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
    window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7xf32>
  func.return %1 : tensor<4x3x7x7xf32>
}

// CHECK-LABEL:   func @convert_maxpool_valid_3d(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool3D"(%[[VAL_0]]) <{data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]}> : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
// CHECK:         }
// Channel-last unpadded 3x3x3 max pool with stride 2 over the three spatial
// dimensions.
func.func @convert_maxpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
  // 0xFF800000 is the f32 bit pattern of -infinity, the identity for max.
  %neg_inf = mhlo.constant dense<0xFF800000> : tensor<f32>
  %pooled = "mhlo.reduce_window"(%arg0, %neg_inf) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    mhlo.return %max : tensor<f32>
  }) {base_dilations = dense<1> : tensor<5xi64>, padding = dense<0> : tensor<5x2xi64>, window_dilations = dense<1> : tensor<5xi64>, window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>, window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
  func.return %pooled : tensor<4x7x7x7x8xf32>
}

// CHECK-LABEL:   func @convert_maxpool_valid_3d_channel_first(
// CHECK-SAME:                                %[[VAL_0:.*]]: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool3D"(%[[VAL_0]]) <{data_format = "NCDHW", ksize = [1, 1, 3, 3, 3], padding = "VALID", strides = [1, 1, 2, 2, 2]}> : (tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x3x7x7x7xf32>
// CHECK:         }
// Channel-first (NCDHW) unpadded 3x3x3 max pool with stride 2; the window and
// stride are 1 on the batch and channel dimensions.
func.func @convert_maxpool_valid_3d_channel_first(%arg0: tensor<4x3x16x16x16xf32>) -> tensor<4x3x7x7x7xf32> {
  // "0xFF800000" represents -INF for f32.
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
    mhlo.return %2 : tensor<f32>
  }) {
    base_dilations = dense<1> : tensor<5xi64>,
    padding = dense<0> : tensor<5x2xi64>,
    window_dilations = dense<1> : tensor<5xi64>,
    window_dimensions = dense<[1, 1, 3, 3, 3]> : tensor<5xi64>,
    window_strides = dense<[1, 1, 2, 2, 2]> : tensor<5xi64>} : (tensor<4x3x16x16x16xf32>, tensor<f32>) -> tensor<4x3x7x7x7xf32>
  func.return %1 : tensor<4x3x7x7x7xf32>
}

// CHECK-LABEL:   func @convert_maxpool_same(
// CHECK-SAME:                               %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
// CHECK:           %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) <{data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
// CHECK:           return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK:         }
// 3x3 max pool with stride 2 and one element of high-side padding on each
// spatial dimension, which the converted form expresses as SAME padding.
func.func @convert_maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
  // 0xFF800000 is the f32 bit pattern of -infinity, the identity for max.
  %neg_inf = mhlo.constant dense<0xFF800000> : tensor<f32>
  %pooled = "mhlo.reduce_window"(%arg0, %neg_inf) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
    mhlo.return %max : tensor<f32>
  }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
  func.return %pooled : tensor<4x8x8x8xf32>
}

// CHECK-LABEL:   func @convert_pad(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<8x128xf32>,
// CHECK-SAME:                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<11x131xf32> {
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.PadV2"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<8x128xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<11x131xf32>
// CHECK:           return %[[VAL_3]] : tensor<11x131xf32>
// CHECK:         }
// Non-negative edge padding with no interior padding maps directly onto
// tf.PadV2, whose paddings operand holds one [low, high] pair per dimension.
func.func @convert_pad(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<11x131xf32> {
  %padded = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 3]> : tensor<2xi64>, edge_padding_low = dense<[1, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<8x128xf32>, tensor<f32>) -> tensor<11x131xf32>
  func.return %padded : tensor<11x131xf32>
}

// CHECK-LABEL:   func @convert_pad_negative_amount(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<8x128xf32>,
// CHECK-SAME:                      %[[VAL_1:.*]]: tensor<f32>) -> tensor<7x128xf32> {
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.PadV2"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<8x128xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<8x129xf32>
// CHECK:           %[[VAL_4:.*]] = arith.constant dense<[0, 1]> : tensor<2xi64>
// CHECK:           %[[VAL_5:.*]] = arith.constant dense<[7, 128]> : tensor<2xi64>
// CHECK:           %[[VAL_6:.*]] = "tf.Slice"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<8x129xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<7x128xf32>
// CHECK:           return %[[VAL_6]] : tensor<7x128xf32>
// CHECK:         }
// Mixed-sign edge padding: the converted form applies only the positive
// amounts with tf.PadV2, then crops the negative amounts away with tf.Slice.
func.func @convert_pad_negative_amount(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<7x128xf32> {
  %padded = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[-1, 1]> : tensor<2xi64>, edge_padding_low = dense<[0, -1]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<8x128xf32>, tensor<f32>) -> tensor<7x128xf32>
  func.return %padded : tensor<7x128xf32>
}

// CHECK-LABEL:   func @convert_round(
// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<8x128xbf16>) -> tensor<8x128xbf16>
// CHECK:           %[[VAL_1:.*]] = "tf.Round"(%[[VAL_0]]) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
// CHECK:           return %[[VAL_1]]
// CHECK:         }
// Round-half-to-even expansion: floor(x) is incremented when the fractional
// part exceeds 0.5, or equals 0.5 and floor(x) is odd
// (floor(x) - 2 * floor(x * 0.5) == 1).
func.func @convert_round(%arg0: tensor<8x128xbf16>) -> tensor<8x128xbf16> {
  %two = mhlo.constant dense<2.000000e+00> : tensor<8x128xbf16>
  %half = mhlo.constant dense<5.000000e-01> : tensor<8x128xbf16>
  %one = mhlo.constant dense<1.000000e+00> : tensor<8x128xbf16>
  %floor = "mhlo.floor"(%arg0) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
  %frac = mhlo.subtract %arg0, %floor : tensor<8x128xbf16>
  %above_half = "mhlo.compare"(%frac, %half) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
  %at_half = "mhlo.compare"(%frac, %half) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
  %x_halved = mhlo.multiply %arg0, %half : tensor<8x128xbf16>
  %halved_floor = "mhlo.floor"(%x_halved) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
  %nearest_even = mhlo.multiply %halved_floor, %two : tensor<8x128xbf16>
  %parity = mhlo.subtract %floor, %nearest_even : tensor<8x128xbf16>
  %floor_is_odd = "mhlo.compare"(%parity, %one) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
  %odd_tie = mhlo.and %at_half, %floor_is_odd : tensor<8x128xi1>
  %round_up = mhlo.or %above_half, %odd_tie : tensor<8x128xi1>
  %floor_plus_one = mhlo.add %floor, %one : tensor<8x128xbf16>
  %rounded = "mhlo.select"(%round_up, %floor_plus_one, %floor) : (tensor<8x128xi1>, tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xbf16>
  func.return %rounded : tensor<8x128xbf16>
}

// CHECK-LABEL: func @convert_floor_mod_float
// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %arg1) : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor-modulo expansion: the truncating remainder is shifted by the divisor
// when it is non-zero and its sign differs from the divisor's sign.
func.func @convert_floor_mod_float(%arg0: tensor<192x8xbf16>, %arg1: tensor<192x8xbf16>) -> tensor<192x8xbf16> {
  %zero = mhlo.constant dense<0.000000e+00> : tensor<192x8xbf16>
  %rem = mhlo.remainder %arg0, %arg1 : tensor<192x8xbf16>
  %rem_neg = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
  %div_neg = "mhlo.compare"(%arg1, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
  %sign_differs = "mhlo.compare"(%rem_neg, %div_neg) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
  %needs_fix = mhlo.and %sign_differs, %rem_nonzero : tensor<192x8xi1>
  %shifted = mhlo.add %rem, %arg1 : tensor<192x8xbf16>
  %result = "mhlo.select"(%needs_fix, %shifted, %rem) : (tensor<192x8xi1>, tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
  func.return %result : tensor<192x8xbf16>
}

// CHECK-LABEL: func @convert_floor_mod_int
// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %arg1) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
// CHECK: return %[[RESULT]]
// CHECK: }
// Integer floor-modulo expansion: the truncating remainder is shifted by the
// divisor when it is non-zero and its sign differs from the divisor's sign.
func.func @convert_floor_mod_int(%arg0: tensor<192x8xi32>, %arg1: tensor<192x8xi32>) -> tensor<192x8xi32> {
  %zero = mhlo.constant dense<0> : tensor<192x8xi32>
  %rem = mhlo.remainder %arg0, %arg1 : tensor<192x8xi32>
  %rem_neg = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
  %div_neg = "mhlo.compare"(%arg1, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
  %sign_differs = "mhlo.compare"(%rem_neg, %div_neg) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
  %needs_fix = mhlo.and %sign_differs, %rem_nonzero : tensor<192x8xi1>
  %shifted = mhlo.add %rem, %arg1 : tensor<192x8xi32>
  %result = "mhlo.select"(%needs_fix, %shifted, %rem) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
  func.return %result : tensor<192x8xi32>
}

// CHECK-LABEL: func @convert_floor_mod_float_cst
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<192x8xbf16>}> : () -> tensor<192x8xbf16>
// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<192x8xbf16>}> : () -> tensor<192x8xbf16>
// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
// CHECK: return %[[RESULT]] : tensor<192x8xbf16>
// CHECK: }
// Floor-modulo by the constant 2.0: the divisor-sign comparison is absent, so
// the remainder is shifted only when it is negative and non-zero.
func.func @convert_floor_mod_float_cst(%arg0: tensor<192x8xbf16>) -> tensor<192x8xbf16> {
  %zero = mhlo.constant dense<0.000000e+00> : tensor<192x8xbf16>
  %two = mhlo.constant dense<2.000000e+00> : tensor<192x8xbf16>
  %rem = mhlo.remainder %arg0, %two : tensor<192x8xbf16>
  %rem_neg = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
  %needs_fix = mhlo.and %rem_neg, %rem_nonzero : tensor<192x8xi1>
  %shifted = mhlo.add %rem, %two : tensor<192x8xbf16>
  %result = "mhlo.select"(%needs_fix, %shifted, %rem) : (tensor<192x8xi1>, tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
  func.return %result : tensor<192x8xbf16>
}

// CHECK-LABEL: func @convert_floor_mod_int_cst
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<2> : tensor<192x8xi32>}> : () -> tensor<192x8xi32>
// CHECK-DAG: %[[CST2:.*]] = "tf.Const"() <{value = dense<2> : tensor<192x8xi32>}> : () -> tensor<192x8xi32>
// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
// CHECK: return %[[RESULT]] : tensor<192x8xi32>
// CHECK: }
// Integer floor-modulo by the constant 2: the divisor-sign comparison is
// absent, so the remainder is shifted only when it is negative and non-zero.
func.func @convert_floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi32> {
  %zero = mhlo.constant dense<0> : tensor<192x8xi32>
  %two = mhlo.constant dense<2> : tensor<192x8xi32>
  %rem = mhlo.remainder %arg0, %two : tensor<192x8xi32>
  %rem_neg = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1>
  %needs_fix = mhlo.and %rem_neg, %rem_nonzero : tensor<192x8xi1>
  %shifted = mhlo.add %rem, %two : tensor<192x8xi32>
  %result = "mhlo.select"(%needs_fix, %shifted, %rem) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32>
  func.return %result : tensor<192x8xi32>
}

// CHECK-LABEL: func @convert_floor_mod_bfloat
// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %arg1) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor-modulo expansion written with the custom mhlo.compare syntax
// (explicit FLOAT / UNSIGNED compare types) instead of the generic form.
func.func @convert_floor_mod_bfloat(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> {
  %zero = mhlo.constant dense<0.000000e+00> : tensor<10x10xbf16>
  %rem = mhlo.remainder %arg0, %arg1 : tensor<10x10xbf16>
  %rem_nonzero = mhlo.compare NE, %rem, %zero, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %rem_neg = mhlo.compare LT, %rem, %zero, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %div_neg = mhlo.compare LT, %arg1, %zero, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %sign_differs = mhlo.compare NE, %rem_neg, %div_neg, UNSIGNED : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<10x10xi1>
  %needs_fix = mhlo.and %sign_differs, %rem_nonzero : tensor<10x10xi1>
  %shifted = mhlo.add %rem, %arg1 : tensor<10x10xbf16>
  %result = mhlo.select %needs_fix, %shifted, %rem : tensor<10x10xi1>, tensor<10x10xbf16>
  func.return %result : tensor<10x10xbf16>
}

// CHECK-LABEL: func @convert_floor_div
// CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %arg1) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor-division expansion: (x - rem) / y, decremented by one when rem is
// non-zero and sign(rem) differs from sign(y), then rounded away from zero.
func.func @convert_floor_div(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> {
  %zero = mhlo.constant dense<0.000000e+00> : tensor<10x10xbf16>
  %neg_one = mhlo.constant dense<-1.000000e+00> : tensor<10x10xbf16>
  %rem = mhlo.remainder %arg0, %arg1 : tensor<10x10xbf16>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %div_sign = "mhlo.sign"(%arg1) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %rem_sign = "mhlo.sign"(%rem) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %signs_differ = "mhlo.compare"(%div_sign, %rem_sign) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %needs_fix = mhlo.and %rem_nonzero, %signs_differ : tensor<10x10xi1>
  %exact_num = mhlo.subtract %arg0, %rem : tensor<10x10xbf16>
  %quot = mhlo.divide %exact_num, %arg1 : tensor<10x10xbf16>
  %quot_minus_one = mhlo.add %quot, %neg_one : tensor<10x10xbf16>
  %floored = "mhlo.select"(%needs_fix, %quot_minus_one, %quot) : (tensor<10x10xi1>, tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %rounded = "mhlo.round_nearest_afz"(%floored) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  // This tuple is never used; the function returns the rounded quotient.
  %unused_tuple = "mhlo.tuple"(%rounded) : (tensor<10x10xbf16>) -> tuple<tensor<10x10xbf16>>
  func.return %rounded : tensor<10x10xbf16>
}

// CHECK-LABEL: func @convert_floor_div_cst
// CHECK: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<10x10xbf16>}> : () -> tensor<10x10xbf16>
// CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %[[CST2]]) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor division by the constant 2.0: the quotient is formed by multiplying by
// 0.5, and sign(divisor) appears as the constant 1.0.
func.func @convert_floor_div_cst(%arg0: tensor<10x10xbf16>) -> tensor<10x10xbf16> {
  %two = mhlo.constant dense<2.000000e+00> : tensor<10x10xbf16>
  %zero = mhlo.constant dense<0.000000e+00> : tensor<10x10xbf16>
  %one = mhlo.constant dense<1.000000e+00> : tensor<10x10xbf16>
  %half = mhlo.constant dense<5.000000e-01> : tensor<10x10xbf16>
  %neg_one = mhlo.constant dense<-1.000000e+00> : tensor<10x10xbf16>
  %rem = mhlo.remainder %arg0, %two : tensor<10x10xbf16>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %rem_sign = "mhlo.sign"(%rem) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %signs_differ = "mhlo.compare"(%one, %rem_sign) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %needs_fix = mhlo.and %rem_nonzero, %signs_differ : tensor<10x10xi1>
  %exact_num = mhlo.subtract %arg0, %rem : tensor<10x10xbf16>
  %quot = mhlo.multiply %exact_num, %half : tensor<10x10xbf16>
  %quot_minus_one = mhlo.add %quot, %neg_one : tensor<10x10xbf16>
  %floored = "mhlo.select"(%needs_fix, %quot_minus_one, %quot) : (tensor<10x10xi1>, tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %rounded = "mhlo.round_nearest_afz"(%floored) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  // This tuple is never used; the function returns the rounded quotient.
  %unused_tuple = "mhlo.tuple"(%rounded) : (tensor<10x10xbf16>) -> tuple<tensor<10x10xbf16>>
  func.return %rounded : tensor<10x10xbf16>
}

// CHECK-LABEL: func @convert_floor_div_cst2
// CHECK: %[[CST2:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<10x10xbf16>}> : () -> tensor<10x10xbf16>
// CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %[[CST2]]) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor division by the constant 2.0 using mhlo.divide directly; sign(divisor)
// appears as the constant 1.0.
func.func @convert_floor_div_cst2(%arg0: tensor<10x10xbf16>) -> tensor<10x10xbf16> {
  %one = mhlo.constant dense<1.000000e+00> : tensor<10x10xbf16>
  %two = mhlo.constant dense<2.000000e+00> : tensor<10x10xbf16>
  %zero = mhlo.constant dense<0.000000e+00> : tensor<10x10xbf16>
  %neg_one = mhlo.constant dense<-1.000000e+00> : tensor<10x10xbf16>
  %rem = mhlo.remainder %arg0, %two : tensor<10x10xbf16>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %rem_sign = "mhlo.sign"(%rem) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %signs_differ = "mhlo.compare"(%one, %rem_sign) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1>
  %needs_fix = mhlo.and %rem_nonzero, %signs_differ : tensor<10x10xi1>
  %exact_num = mhlo.subtract %arg0, %rem : tensor<10x10xbf16>
  %quot = mhlo.divide %exact_num, %two : tensor<10x10xbf16>
  %quot_minus_one = mhlo.add %quot, %neg_one : tensor<10x10xbf16>
  %floored = "mhlo.select"(%needs_fix, %quot_minus_one, %quot) : (tensor<10x10xi1>, tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
  %rounded = "mhlo.round_nearest_afz"(%floored) : (tensor<10x10xbf16>) -> tensor<10x10xbf16>
  // This tuple is never used; the function returns the rounded quotient.
  %unused_tuple = "mhlo.tuple"(%rounded) : (tensor<10x10xbf16>) -> tuple<tensor<10x10xbf16>>
  func.return %rounded : tensor<10x10xbf16>
}

// CHECK-LABEL: func @convert_floor_div_broadcast_cst
// CHECK: %[[BCST:.*]] = "tf.BroadcastTo"{{.*}} : (tensor<8xf32>, tensor<2xi64>) -> tensor<10x8xf32>
// CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %[[BCST]]) : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32>
// CHECK: return %[[RESULT]]
// CHECK: }
// Floor division by a per-column constant vector of positive divisors that is
// broadcast along dimension 1; the converted form materializes it with
// tf.BroadcastTo, and sign(divisor) appears as the constant 1.0.
func.func @convert_floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10x8xf32> {
  %one = mhlo.constant dense<1.000000e+00> : tensor<10x8xf32>
  %divisors = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32>
  %zero = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32>
  %neg_one = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32>
  %divisor = "mhlo.broadcast_in_dim"(%divisors) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32>
  %rem = mhlo.remainder %arg0, %divisor : tensor<10x8xf32>
  %rem_nonzero = "mhlo.compare"(%rem, %zero) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
  %rem_sign = "mhlo.sign"(%rem) : (tensor<10x8xf32>) -> tensor<10x8xf32>
  %signs_differ = "mhlo.compare"(%one, %rem_sign) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1>
  %needs_fix = mhlo.and %rem_nonzero, %signs_differ : tensor<10x8xi1>
  %exact_num = mhlo.subtract %arg0, %rem : tensor<10x8xf32>
  %quot = mhlo.divide %exact_num, %divisor : tensor<10x8xf32>
  %quot_minus_one = mhlo.add %quot, %neg_one : tensor<10x8xf32>
  %floored = "mhlo.select"(%needs_fix, %quot_minus_one, %quot) : (tensor<10x8xi1>, tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32>
  %rounded = "mhlo.round_nearest_afz"(%floored) : (tensor<10x8xf32>) -> tensor<10x8xf32>
  // This tuple is never used; the function returns the rounded quotient.
  %unused_tuple = "mhlo.tuple"(%rounded) : (tensor<10x8xf32>) -> tuple<tensor<10x8xf32>>
  func.return %rounded : tensor<10x8xf32>
}


// CHECK-LABEL:   func @convert_gather(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<147456xf16>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<192x256x1xi32>)
// CHECK:            %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<192x256xf16>
// CHECK:            return %[[VAL_0]]
// CHECK:         }
// Scalar gather from a 1-D operand whose indices already carry a trailing index
// vector of size 1; it maps directly onto tf.GatherNd with no reshape.
func.func @convert_gather(%arg0: tensor<147456xf16>, %arg1: tensor<192x256x1xi32>) -> tensor<192x256xf16> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      collapsed_slice_dims = [0],
      index_vector_dim = 2,
      start_index_map = [0],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<1> : tensor<1xi64>
  } : (tensor<147456xf16>, tensor<192x256x1xi32>) -> tensor<192x256xf16>
  func.return %0 : tensor<192x256xf16>
}

// CHECK-LABEL:   func @convert_gather_with_ui32indices(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<147456xf16>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<192x256x1xui32>)
// CHECK:            %[[INDICES:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<192x256x1xui32>) -> tensor<192x256x1xi64>
// CHECK:            %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[INDICES]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<192x256xf16>
// CHECK:            return %[[VAL_0]]
// CHECK:         }
// Same gather shape with unsigned 32-bit indices; the converted form first
// casts the indices to i64 with tf.Cast before tf.GatherNd.
func.func @convert_gather_with_ui32indices(%arg0: tensor<147456xf16>, %arg1: tensor<192x256x1xui32>) -> tensor<192x256xf16> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      collapsed_slice_dims = [0],
      index_vector_dim = 2,
      start_index_map = [0],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<1> : tensor<1xi64>
  } : (tensor<147456xf16>, tensor<192x256x1xui32>) -> tensor<192x256xf16>
  func.return %0 : tensor<192x256xf16>
}

// CHECK-LABEL:   func @convert_gather_nd(
// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<98x128xf32>,
// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<4x64xi32>)
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<[4, 64, 1]> : tensor<3xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_2]]) : {{.*}} -> tensor<4x64x1xi32>
// CHECK:           %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_3]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<4x64x128xf32>
// CHECK:           return %[[VAL_4]]
// CHECK:         }
// Row gather where index_vector_dim equals the rank of the 4x64 indices, so
// the converted form reshapes the indices to 4x64x1 before tf.GatherNd.
func.func @convert_gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> tensor<4x64x128xf32> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      collapsed_slice_dims = [0],
      index_vector_dim = 2,
      offset_dims = [2],
      start_index_map = [0],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<[1, 128]> : tensor<2xi64>
  } : (tensor<98x128xf32>, tensor<4x64xi32>) -> tensor<4x64x128xf32>
  func.return %0 : tensor<4x64x128xf32>
}

// CHECK-LABEL:   func @convert_gather_transpose(
// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<128x256xf32>,
// CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<4x1xi32>) -> tensor<4x128xf32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"{{.*}}value = dense<[1, 0]> : tensor<2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : {{.*}} -> tensor<256x128xf32>
// CHECK:           %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<4x128xf32>
// CHECK:           return %[[VAL_4]]
// CHECK:         }
// Test the case where start_index_map is not an iota, which requires a
// transpose of the operand to convert the gather to tf.GatherNd.
func.func @convert_gather_transpose(%arg0: tensor<128x256xf32>, %arg1: tensor<4x1xi32>) -> tensor<4x128xf32> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      collapsed_slice_dims = [1],
      index_vector_dim = 1,
      offset_dims = [1],
      start_index_map = [1],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<[128, 1]> : tensor<2xi64>
  } : (tensor<128x256xf32>, tensor<4x1xi32>) -> tensor<4x128xf32>
  func.return %0 : tensor<4x128xf32>
}

// CHECK-LABEL: func @convert_gather_offset(
// CHECK-SAME:                                      %[[VAL_0:.*]]: tensor<1x20xi32>,
// CHECK-SAME:                                      %[[VAL_1:.*]]: tensor<1x1xi32>) -> tensor<1x1xi32> {
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Transpose"(%[[VAL_0]], %[[VAL_2]]) : (tensor<1x20xi32>, tensor<2xi64>) -> tensor<20x1xi32>
// CHECK:           %[[VAL_4:.*]] = "tf.GatherNd"(%[[VAL_3]], %[[VAL_1]]) <{bad_indices_policy = ""}> : (tensor<20x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_6:.*]] = "tf.Transpose"(%[[VAL_4]], %[[VAL_5]]) : (tensor<1x1xi32>, tensor<2xi64>) -> tensor<1x1xi32>
// CHECK:           return %[[VAL_6]] : tensor<1x1xi32>
// CHECK:         }
// Gather with a non-iota start_index_map and offset_dims = [0]: the operand
// is transposed before tf.GatherNd, and the result is transposed again
// afterwards.
func.func @convert_gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32>) -> tensor<1x1xi32> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      collapsed_slice_dims = [1],
      index_vector_dim = 1,
      offset_dims = [0],
      start_index_map = [1],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<1> : tensor<2xi64>
  } : (tensor<1x20xi32>, tensor<1x1xi32>) -> tensor<1x1xi32>
  func.return %0 : tensor<1x1xi32>
}

// Gather with operand/start-indices batching dims: the batch dims are
// flattened, tf.Range-generated batch indices are concatenated onto the start
// indices, and one tf.GatherNd is emitted before reshaping/transposing back.
// CHECK-LABEL:   func @convert_gather_batching_dims(
// CHECK-SAME:                          %[[ARG_0:.*]]: tensor<2x3x128xf32>,
// CHECK-SAME:                          %[[ARG_1:.*]]: tensor<3x128x2x1xi32>)
// CHECK-DAG:         %[[CST:.*]] = arith.constant dense<[6, 128]> : tensor<2xi64>
// CHECK:             %[[VAL_0:.*]] = "tf.Reshape"(%[[ARG_0]], %[[CST]]) : (tensor<2x3x128xf32>, tensor<2xi64>) -> tensor<6x128xf32>
// CHECK-DAG:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:             %[[VAL_1:.*]] = "tf.Transpose"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x128x2x1xi32>, tensor<4xi64>) -> tensor<2x3x128x1xi32>
// CHECK-DAG:         %[[CST_1:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64>
// CHECK:             %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST_1]]) : (tensor<2x3x128x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32>
// CHECK-DAG:         %[[CST_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:         %[[CST_3:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:         %[[CST_4:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:             %[[VAL_3:.*]] = "tf.Range"(%[[CST_2]], %[[CST_3]], %[[CST_4]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<6xi32>
// CHECK-DAG:         %[[CST_5:.*]] = "tf.Const"() <{value = dense<[6, 1, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:             %[[VAL_4:.*]] = "tf.Reshape"(%[[VAL_3]], %[[CST_5]]) : (tensor<6xi32>, tensor<3xi64>) -> tensor<6x1x1xi32>
// CHECK-DAG:         %[[CST_6:.*]] = "tf.Const"() <{value = dense<[6, 128, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:             %[[VAL_5:.*]] = "tf.BroadcastTo"(%[[VAL_4]], %[[CST_6]]) : (tensor<6x1x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32>
// CHECK-DAG:         %[[CST_7:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
// CHECK:             %[[VAL_6:.*]] = "tf.ConcatV2"(%[[VAL_5]], %[[VAL_2]], %[[CST_7]]) : (tensor<6x128x1xi32>, tensor<6x128x1xi32>, tensor<i32>) -> tensor<6x128x2xi32>
// CHECK:             %[[VAL_7:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_6]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<6x128xf32>
// CHECK-DAG:         %[[CST_8:.*]] = arith.constant dense<[2, 3, 128]> : tensor<3xi64>
// CHECK:             %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_7]], %[[CST_8]]) : (tensor<6x128xf32>, tensor<3xi64>) -> tensor<2x3x128xf32>
// CHECK-DAG:         %[[CST_9:.*]] = "tf.Const"() <{value = dense<[1, 2, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:             %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x3x128xf32>, tensor<3xi64>) -> tensor<3x128x2xf32>
// CHECK:             return %[[VAL_9]]
// CHECK:         }
func.func @convert_gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32> {
  %gathered = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      operand_batching_dims = [0, 1],
      start_indices_batching_dims = [2, 0],
      collapsed_slice_dims = [2],
      start_index_map = [2],
      index_vector_dim = 3
    >,
    indices_are_sorted = false,
    slice_sizes = dense<1> : tensor<3xi64>
  } : (tensor<2x3x128xf32>, tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32>
  func.return %gathered : tensor<3x128x2xf32>
}

// Gather that keeps the indexed dimension as a size-1 offset dim: lowered to
// tf.GatherNd followed by a tf.Reshape that reinserts the unit dimension.
// CHECK-LABEL: func @convert_gather_non_collapsed_index_dim(
// CHECK-SAME:                                      %[[ARG_0:.*]]: tensor<10x5xi32>,
// CHECK-SAME:                                      %[[ARG_1:.*]]: tensor<2x1xi32>) -> tensor<2x1x5xi32> {
// CHECK:           %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) <{bad_indices_policy = ""}> : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x5xi32>
// CHECK-DAG:       %[[CST:.*]] = arith.constant dense<[2, 1, 5]> : tensor<3xi64>
// CHECK:           %[[VAL_1:.*]] = "tf.Reshape"(%[[VAL_0]], %[[CST]]) : (tensor<2x5xi32>, tensor<3xi64>) -> tensor<2x1x5xi32>
// CHECK:           return %[[VAL_1]] : tensor<2x1x5xi32>
// CHECK:       }
func.func @convert_gather_non_collapsed_index_dim(%arg0: tensor<10x5xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x1x5xi32> {
  %rows = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      start_index_map = [0],
      offset_dims = [1, 2],
      index_vector_dim = 1
    >,
    indices_are_sorted = false,
    slice_sizes = dense<[1, 5]> : tensor<2xi64>
  } : (tensor<10x5xi32>, tensor<2x1xi32>) -> tensor<2x1x5xi32>
  func.return %rows : tensor<2x1x5xi32>
}

// Gather whose start_index_map includes dim 2, which is also sliced with size
// > 1: the lowering offsets the start indices by a tf.Range over that slice
// (via tf.PadV2 + tf.Add) so one tf.GatherNd fetches every element, then
// reshapes and transposes the result back. The padding constant is matched
// with the {LITERAL} modifier so its `[[` brackets are not parsed as a
// pattern variable; without the closing brace the directive was ignored.
// CHECK-LABEL: func @convert_gather_indexed_dimension_slice(
// CHECK-SAME:                                      %[[ARG_0:.*]]: tensor<4x5x6xi32>,
// CHECK-SAME:                                      %[[ARG_1:.*]]: tensor<2x2xi32>) -> tensor<2x1x5x6xi32> {
// CHECK-DAG:       %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_0:.*]] = "tf.Transpose"(%[[ARG_0]], %[[CST]]) : (tensor<4x5x6xi32>, tensor<3xi64>) -> tensor<4x6x5xi32>
// CHECK-DAG:       %[[CST_0:.*]] = arith.constant dense<[2, 1, 2]> : tensor<3xi64>
// CHECK:           %[[VAL_1:.*]] = "tf.Reshape"(%[[ARG_1]], %[[CST_0]]) : (tensor<2x2xi32>, tensor<3xi64>) -> tensor<2x1x2xi32>
// CHECK-DAG:       %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[CST_2:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[CST_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_2:.*]] = "tf.Range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<6xi32>
// CHECK-DAG:       %[[CST_4:.*]] = "tf.Const"() <{value = dense<[1, 6, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_2]], %[[CST_4]]) : (tensor<6xi32>, tensor<3xi64>) -> tensor<1x6x1xi32>
// CHECK-DAG:       %[[CST_5:.*]] = "tf.Const"() <{value = dense<[1, 6, 1]> : tensor<3xi64>}> : () -> tensor<3xi64>
// CHECK:           %[[VAL_4:.*]] = "tf.BroadcastTo"(%[[VAL_3]], %[[CST_5]]) : (tensor<1x6x1xi32>, tensor<3xi64>) -> tensor<1x6x1xi32>
// CHECK-DAG:       %[[CST_6:.*]] = arith.constant dense<0> : tensor<i32>
// CHECK-DAG:       %[[CST_7:.*]] = arith.constant
// CHECK-SAME{LITERAL}:  dense<[[0, 0], [0, 0], [1, 0]]> : tensor<3x2xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.PadV2"(%[[VAL_4]], %[[CST_7]], %[[CST_6]]) : (tensor<1x6x1xi32>, tensor<3x2xi64>, tensor<i32>) -> tensor<1x6x2xi32>
// CHECK:           %[[VAL_6:.*]] = "tf.Add"(%[[VAL_1]], %[[VAL_5]]) : (tensor<2x1x2xi32>, tensor<1x6x2xi32>) -> tensor<2x6x2xi32>
// CHECK:           %[[VAL_7:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_6]]) <{bad_indices_policy = ""}> : (tensor<4x6x5xi32>, tensor<2x6x2xi32>) -> tensor<2x6x5xi32>
// CHECK-DAG:       %[[CST_8:.*]] = arith.constant dense<[2, 1, 6, 5]> : tensor<4xi64>
// CHECK:           %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_7]], %[[CST_8]]) : (tensor<2x6x5xi32>, tensor<4xi64>) -> tensor<2x1x6x5xi32>
// CHECK-DAG:       %[[CST_9:.*]] = "tf.Const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64>
// CHECK:           %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x1x6x5xi32>, tensor<4xi64>) -> tensor<2x1x5x6xi32>
// CHECK:           return %[[VAL_9]] : tensor<2x1x5x6xi32>
// CHECK:       }
func.func @convert_gather_indexed_dimension_slice(%arg0: tensor<4x5x6xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x1x5x6xi32> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      index_vector_dim = 1,
      offset_dims = [1, 2, 3],
      start_index_map = [0, 2],
    >,
    indices_are_sorted = false,
    slice_sizes = dense<[1, 5, 6]> : tensor<3xi64>
  } : (tensor<4x5x6xi32>, tensor<2x2xi32>) -> tensor<2x1x5x6xi32>
  func.return %0 : tensor<2x1x5x6xi32>
}

// Gather with a single sorted index row: the start indices are clamped into
// [0, operand_dim - slice_size] with tf.Maximum/tf.Minimum, squeezed, and fed
// to a single tf.Slice. The tf.Maximum operand references the zero constant
// captured above instead of redefining the pattern variable, so the clamp's
// lower bound is actually verified.
// CHECK-LABEL:   func @convert_gather_to_slice_batch_size_1(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<1x2944xi32>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<1x2xi32>)
// CHECK-DAG:         %[[CST:.*]] = "tf.Const"() <{value = dense<[0, 1440]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:             %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK:             %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<1x2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK-DAG:         %[[CST_1:.*]] = "tf.Const"() <{value = dense<[1, 1504]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:             %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK:             %[[VAL_3:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_2]], %[[CST_1]]) : (tensor<1x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
// CHECK:            return %[[VAL_3]]
// CHECK:         }
func.func @convert_gather_to_slice_batch_size_1(%arg0: tensor<1x2944xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x1504xi32> {
  %0 = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      offset_dims = [1],
      collapsed_slice_dims = [0],
      start_index_map = [0, 1],
      index_vector_dim = 1,
    >,
    indices_are_sorted = true,
    slice_sizes = dense<[1, 1504]> : tensor<2xi64>
  } : (tensor<1x2944xi32>, tensor<1x2xi32>) -> tensor<1x1504xi32>
  func.return %0 : tensor<1x1504xi32>
}

// Gather of whole rows with dynamically shaped start indices lowers straight
// to tf.GatherNd.
// CHECK-LABEL:   func @convert_gather_slice_dynamic_indices(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<256000x1024xi8>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<?x?x1xi32>) -> tensor<?x?x1024xi8> {
// CHECK:            %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) <{bad_indices_policy = ""}> : (tensor<256000x1024xi8>, tensor<?x?x1xi32>) -> tensor<?x?x1024xi8>
// CHECK:            return %[[VAL_0]] : tensor<?x?x1024xi8>
// CHECK:         }
func.func @convert_gather_slice_dynamic_indices(%arg0: tensor<256000x1024xi8>, %arg1: tensor<?x?x1xi32>) -> tensor<?x?x1024xi8> {
  %rows = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      index_vector_dim = 2,
      start_index_map = [0],
      collapsed_slice_dims = [0],
      offset_dims = [2]
    >,
    slice_sizes = dense<[1, 1024]> : tensor<2xi64>
  } : (tensor<256000x1024xi8>, tensor<?x?x1xi32>) -> tensor<?x?x1024xi8>
  func.return %rows : tensor<?x?x1024xi8>
}

// Element-wise gather from a rank-1 operand with dynamically shaped start
// indices lowers straight to tf.GatherNd.
// CHECK-LABEL:   func @convert_gather_scalar_dynamic_indices(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<256000xf32>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<?x?x1xi32>) -> tensor<?x?xf32> {
// CHECK:            %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) <{bad_indices_policy = ""}> : (tensor<256000xf32>, tensor<?x?x1xi32>) -> tensor<?x?xf32>
// CHECK:            return %[[VAL_0]] : tensor<?x?xf32>
// CHECK:         }
func.func @convert_gather_scalar_dynamic_indices(%arg0: tensor<256000xf32>, %arg1: tensor<?x?x1xi32>) -> tensor<?x?xf32> {
  %elems = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      index_vector_dim = 2,
      start_index_map = [0],
      collapsed_slice_dims = [0]
    >,
    slice_sizes = dense<1> : tensor<1xi64>
  } : (tensor<256000xf32>, tensor<?x?x1xi32>) -> tensor<?x?xf32>
  func.return %elems : tensor<?x?xf32>
}

// Sorted-index gather over a static operand: indices are clamped into range,
// each index row drives its own tf.Slice, and the slices are joined with
// tf.ConcatV2 along dim 0.
// CHECK-LABEL:   func @convert_gather_to_slice(
// CHECK-SAME:                         %[[ARG_0:.*]]: tensor<3x2944xi32>,
// CHECK-SAME:                         %[[ARG_1:.*]]: tensor<3x2xi32>)
// CHECK-DAG:        %[[CST:.*]] = "tf.Const"() <{value = dense<[2, 1440]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:        %[[CST_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:            %[[VAL_0:.*]] = "tf.Maximum"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
// CHECK:            %[[VAL_1:.*]] = "tf.Minimum"(%[[VAL_0]], %[[CST]]) : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
// CHECK-DAG:        %[[CST_1:.*]] = "tf.Const"() <{value = dense<[1, 1504]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:        %[[CST_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:        %[[CST_3:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:            %[[VAL_2:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_2]], %[[CST_3]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK:            %[[VAL_3:.*]] = "tf.Squeeze"(%[[VAL_2]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK:            %[[VAL_4:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_3]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
// CHECK-DAG:        %[[CST_4:.*]] = "tf.Const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:        %[[CST_5:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:            %[[VAL_5:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_4]], %[[CST_5]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK:            %[[VAL_6:.*]] = "tf.Squeeze"(%[[VAL_5]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK:            %[[VAL_7:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_6]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
// CHECK-DAG:        %[[CST_6:.*]] = "tf.Const"() <{value = dense<[2, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK-DAG:        %[[CST_7:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK:            %[[VAL_8:.*]] = "tf.Slice"(%[[VAL_1]], %[[CST_6]], %[[CST_7]]) : (tensor<3x2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK:            %[[VAL_9:.*]] = "tf.Squeeze"(%[[VAL_8]]) <{squeeze_dims = [0]}> : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK:            %[[VAL_10:.*]] = "tf.Slice"(%[[ARG_0]], %[[VAL_9]], %[[CST_1]]) : (tensor<3x2944xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x1504xi32>
// CHECK-DAG:        %[[CST_8:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:            %[[VAL_11:.*]] = "tf.ConcatV2"(%[[VAL_4]], %[[VAL_7]], %[[VAL_10]], %[[CST_8]]) : (tensor<1x1504xi32>, tensor<1x1504xi32>, tensor<1x1504xi32>, tensor<i32>) -> tensor<3x1504xi32>
// CHECK:            return %[[VAL_11]]
// CHECK:         }
func.func @convert_gather_to_slice(%arg0: tensor<3x2944xi32>, %arg1: tensor<3x2xi32>) -> tensor<3x1504xi32> {
  %windows = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      index_vector_dim = 1,
      start_index_map = [0, 1],
      collapsed_slice_dims = [0],
      offset_dims = [1]
    >,
    indices_are_sorted = true,
    slice_sizes = dense<[1, 1504]> : tensor<2xi64>
  } : (tensor<3x2944xi32>, tensor<3x2xi32>) -> tensor<3x1504xi32>
  func.return %windows : tensor<3x1504xi32>
}

// The gather-to-slice lowering rejects an operand with a dynamic dimension and
// reports a diagnostic on the gather op.
// CHECK-LABEL:   func @convert_gather_to_slice_dynamic_error
func.func @convert_gather_to_slice_dynamic_error(%arg0: tensor<3x?xi32>, %arg1: tensor<3x2xi32>) -> tensor<3x1504xi32> {
  // expected-error @+1 {{Dynamic shaped operand is not supported.}}
  %result = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = #mhlo.gather<
      index_vector_dim = 1,
      start_index_map = [0, 1],
      collapsed_slice_dims = [0],
      offset_dims = [1]
    >,
    indices_are_sorted = true,
    slice_sizes = dense<[1, 1504]> : tensor<2xi64>
  } : (tensor<3x?xi32>, tensor<3x2xi32>) -> tensor<3x1504xi32>
  func.return %result : tensor<3x1504xi32>
}

// mhlo.dynamic_slice lowers to tf.Slice: each start index is cast to i32,
// clamped to [0, dim - slice_size], and packed into one start vector.
// CHECK-LABEL: func @convert_dynamic_slice(
// CHECK-SAME:                                      %[[VAL_0:.*]]: tensor<7x3xf32>,
// CHECK-SAME:                                      %[[VAL_1:.*]]: tensor<i32>,
// CHECK-SAME:                                      %[[VAL_2:.*]]: tensor<i32>) -> tensor<4x2xf32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) <{Truncate = false}> : (tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_7:.*]] = "tf.Maximum"(%[[VAL_6]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) <{Truncate = false}> : (tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_9:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_10:.*]] = "tf.Minimum"(%[[VAL_8]], %[[VAL_9]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_11:.*]] = "tf.Maximum"(%[[VAL_10]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
// CHECK:           %[[VAL_13:.*]] = "tf.Const"() <{value = dense<[4, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_12]], %[[VAL_13]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
// CHECK:           return %[[VAL_14]] : tensor<4x2xf32>
// CHECK:         }
func.func @convert_dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<4x2xf32> {
  %window = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<i32>, tensor<i32>) -> tensor<4x2xf32>
  func.return %window : tensor<4x2xf32>
}

// Same as the i32 dynamic_slice lowering, but the unsigned start indices are
// cast to i32 before being clamped and packed.
// CHECK-LABEL: func @convert_dynamic_slice_ui32(
// CHECK-SAME:                                           %[[VAL_0:.*]]: tensor<7x3xf32>,
// CHECK-SAME:                                           %[[VAL_1:.*]]: tensor<ui32>,
// CHECK-SAME:                                           %[[VAL_2:.*]]: tensor<ui32>) -> tensor<4x2xf32> {
// CHECK:           %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_4:.*]] = "tf.Cast"(%[[VAL_1]]) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<3> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_6:.*]] = "tf.Minimum"(%[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_7:.*]] = "tf.Maximum"(%[[VAL_6]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_8:.*]] = "tf.Cast"(%[[VAL_2]]) <{Truncate = false}> : (tensor<ui32>) -> tensor<i32>
// CHECK:           %[[VAL_9:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_10:.*]] = "tf.Minimum"(%[[VAL_8]], %[[VAL_9]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_11:.*]] = "tf.Maximum"(%[[VAL_10]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:           %[[VAL_12:.*]] = "tf.Pack"(%[[VAL_7]], %[[VAL_11]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
// CHECK:           %[[VAL_13:.*]] = "tf.Const"() <{value = dense<[4, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_14:.*]] = "tf.Slice"(%[[VAL_0]], %[[VAL_12]], %[[VAL_13]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32>
// CHECK:           return %[[VAL_14]] : tensor<4x2xf32>
// CHECK:         }
func.func @convert_dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor<ui32>, %arg2: tensor<ui32>) -> tensor<4x2xf32> {
  %window = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor<ui32>, tensor<ui32>) -> tensor<4x2xf32>
  func.return %window : tensor<4x2xf32>
}

// A scatter whose combiner just returns the update value is left as
// mhlo.scatter by this pass, with its dimension numbers unchanged.
// CHECK-LABEL:   func.func @convert_scatter_update(
// CHECK-SAME:                                      %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME:                                      %[[VAL_1:.*]]: tensor<4xi32>,
// CHECK-SAME:                                      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK:             mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:           }) : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK:           return %[[VAL_3]] : tensor<20x6xf32>
// CHECK:         }
func.func @convert_scatter_update(%arg0: tensor<20x6xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%current: tensor<f32>, %update: tensor<f32>):
    mhlo.return %update : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [0],
      inserted_window_dims = [0],
      update_window_dims = [1]
    >,
    unique_indices = false
  } : (tensor<20x6xf32>, tensor<4xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
  func.return %scattered : tensor<20x6xf32>
}

// Replacing scatter whose update window dim (0) precedes the scatter dim in the
// updates tensor; the op is kept as mhlo.scatter unchanged.
// CHECK-LABEL:   func.func @convert_scatter_update_with_non_trailing_update_window_dims(
// CHECK-SAME:                                                                           %[[VAL_0:.*]]: tensor<5x10xf32>,
// CHECK-SAME:                                                                           %[[VAL_1:.*]]: tensor<3x1xi32>,
// CHECK-SAME:                                                                           %[[VAL_2:.*]]: tensor<10x3xf32>) -> tensor<5x10xf32> {
// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:         indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK:             mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:           }) : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32>
// CHECK:           return %[[VAL_3]] : tensor<5x10xf32>
// CHECK:         }
func.func @convert_scatter_update_with_non_trailing_update_window_dims(
    %arg0: tensor<5x10xf32>,
    %arg1: tensor<3x1xi32>,
    %arg2: tensor<10x3xf32>) -> tensor<5x10xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%current: tensor<f32>, %update: tensor<f32>):
    mhlo.return %update : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [0],
      inserted_window_dims = [0],
      update_window_dims = [0]
    >,
    unique_indices = false
  } : (tensor<5x10xf32>, tensor<3x1xi32>, tensor<10x3xf32>) -> tensor<5x10xf32>
  func.return %scattered : tensor<5x10xf32>
}

// Replacing scatter into non-trailing operand dims (1 and 3); the op is kept as
// mhlo.scatter unchanged.
// CHECK-LABEL:   func.func @convert_scatter_update_to_non_trailing_operand_dimensions(
// CHECK-SAME:                                                                         %[[VAL_0:.*]]: tensor<5x4x3x7xf32>,
// CHECK-SAME:                                                                         %[[VAL_1:.*]]: tensor<2x2xi32>,
// CHECK-SAME:                                                                         %[[VAL_2:.*]]: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> {
// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1, 2], inserted_window_dims = [1, 3], scatter_dims_to_operand_dims = [1, 3], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK:             mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:           }) : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32>
// CHECK:           return %[[VAL_3]] : tensor<5x4x3x7xf32>
// CHECK:         }
func.func @convert_scatter_update_to_non_trailing_operand_dimensions(
    %arg0: tensor<5x4x3x7xf32>,
    %arg1: tensor<2x2xi32>,
    %arg2: tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%current: tensor<f32>, %update: tensor<f32>):
    mhlo.return %update : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [1, 3],
      inserted_window_dims = [1, 3],
      update_window_dims = [1, 2]
    >,
    unique_indices = false
  } : (tensor<5x4x3x7xf32>, tensor<2x2xi32>, tensor<2x5x3xf32>) -> tensor<5x4x3x7xf32>
  func.return %scattered : tensor<5x4x3x7xf32>
}

// Replacing scatter with a rank-1 index tensor and no explicit index_vector_dim;
// the op is kept as mhlo.scatter with sorted/unique flags preserved.
// CHECK-LABEL:   func.func @convert_scatter_update_reshape_indices_and_updates(
// CHECK-SAME:                                                                  %[[VAL_0:.*]]: tensor<16x1504xf32>,
// CHECK-SAME:                                                                  %[[VAL_1:.*]]: tensor<1xi32>,
// CHECK-SAME:                                                                  %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16x1504xf32> {
// CHECK:           %[[VAL_3:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true
// CHECK-SAME:      }> ({
// CHECK:           ^bb0(%[[VAL_4:.*]]: tensor<f32>, %[[VAL_5:.*]]: tensor<f32>):
// CHECK:             mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:           }) : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32>
// CHECK:           return %[[VAL_3]] : tensor<16x1504xf32>
// CHECK:         }
func.func @convert_scatter_update_reshape_indices_and_updates(
    %arg0: tensor<16x1504xf32>,
    %arg1: tensor<1xi32>,
    %arg2: tensor<16xf32>) -> tensor<16x1504xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%current: tensor<f32>, %update: tensor<f32>):
    mhlo.return %update : tensor<f32>
  }) {
    indices_are_sorted = true,
    scatter_dimension_numbers = #mhlo.scatter<
      scatter_dims_to_operand_dims = [1],
      inserted_window_dims = [1],
      update_window_dims = [0]
    >,
    unique_indices = true
  } : (tensor<16x1504xf32>, tensor<1xi32>, tensor<16xf32>) -> tensor<16x1504xf32>
  func.return %scattered : tensor<16x1504xf32>
}

// Scatter with an additive combiner is kept as mhlo.scatter; the mhlo.add in
// its region is preserved with the same operand order.
// CHECK-LABEL:  func.func @convert_scatter_add(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
// CHECK:      %[[VAL_5:.*]] = mhlo.add %[[VAL_3]], %[[VAL_4]] : tensor<f32>
// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:    }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
// CHECK:  }
func.func @convert_scatter_add(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %combined = mhlo.add %lhs, %rhs : tensor<f32>
    mhlo.return %combined : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [0],
      inserted_window_dims = [0],
      update_window_dims = [1]
    >,
    unique_indices = false
  } : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
  func.return %scattered : tensor<20x6xf32>
}

// Scatter with a maximum combiner is kept as mhlo.scatter; the mhlo.maximum in
// its region is preserved with the same operand order.
// CHECK-LABEL:  func.func @convert_scatter_max(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
// CHECK:      %[[VAL_5:.*]] = mhlo.maximum %[[VAL_3]], %[[VAL_4]] : tensor<f32>
// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:    }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
// CHECK:  }
func.func @convert_scatter_max(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %combined = mhlo.maximum %lhs, %rhs : tensor<f32>
    mhlo.return %combined : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [0],
      inserted_window_dims = [0],
      update_window_dims = [1]
    >,
    unique_indices = false
  } : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
  func.return %scattered : tensor<20x6xf32>
}

// Scatter with a minimum combiner is kept as mhlo.scatter; the mhlo.minimum in
// its region is preserved with the same operand order.
// CHECK-LABEL:  func.func @convert_scatter_min(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
// CHECK:      %[[VAL_5:.*]] = mhlo.minimum %[[VAL_3]], %[[VAL_4]] : tensor<f32>
// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:    }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
// CHECK:  }
func.func @convert_scatter_min(%arg0: tensor<20x6xf32>, %arg1: tensor<4x1xi32>, %arg2: tensor<4x6xf32>) -> tensor<20x6xf32> {
  %scattered = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %combined = mhlo.minimum %lhs, %rhs : tensor<f32>
    mhlo.return %combined : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      index_vector_dim = 1,
      scatter_dims_to_operand_dims = [0],
      inserted_window_dims = [0],
      update_window_dims = [1]
    >,
    unique_indices = false
  } : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
  func.return %scattered : tensor<20x6xf32>
}

// CHECK-LABEL:  func.func @convert_scatter_sub(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<20x6xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x1xi32>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x6xf32>) -> tensor<20x6xf32> {
// CHECK:    %[[VAL_6:.*]] = "mhlo.scatter"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{
// CHECK-SAME:          indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false
// CHECK-SAME:      }> ({
// CHECK:    ^bb0(%[[VAL_3:.*]]: tensor<f32>, %[[VAL_4:.*]]: tensor<f32>):
// CHECK:      %[[VAL_5:.*]] = mhlo.subtract %[[VAL_3]], %[[VAL_4]] : tensor<f32>
// CHECK:      mhlo.return %[[VAL_5]] : tensor<f32>
// CHECK:    }) : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
// CHECK:    return %[[VAL_6]] : tensor<20x6xf32>
// CHECK:  }
// Scatter with a subtraction update computation (first block argument minus
// the second, so operand order matters). The op is expected to stay an
// mhlo.scatter after legalization.
func.func @convert_scatter_sub(%operand: tensor<20x6xf32>, %indices: tensor<4x1xi32>, %updates: tensor<4x6xf32>) -> tensor<20x6xf32> {
  %scattered = "mhlo.scatter"(%operand, %indices, %updates) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    %difference = mhlo.subtract %lhs, %rhs : tensor<f32>
    mhlo.return %difference : tensor<f32>
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = #mhlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [0],
      scatter_dims_to_operand_dims = [0],
      index_vector_dim = 1,
    >,
    unique_indices = false
  } : (tensor<20x6xf32>, tensor<4x1xi32>, tensor<4x6xf32>) -> tensor<20x6xf32>
  return %scattered : tensor<20x6xf32>
}

// CHECK-LABEL:   func @convert_pytorch_argmax(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x9xi32>) -> tensor<1xi32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_2:.*]] = "tf.Const"() <{value = dense<-2147483648> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<9> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_5:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_6:.*]] = "tf.Range"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<9xi32>
// CHECK:           %[[VAL_7:.*]] = arith.constant dense<[1, 9]> : tensor<2xi64>
// CHECK:           %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<9xi32>, tensor<2xi64>) -> tensor<1x9xi32>
// CHECK:           %[[VAL_9:.*]] = arith.constant dense<1> : tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK:           %[[VAL_11:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_9]]) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK:           return %[[VAL_11]] : tensor<1xi32>
// CHECK:         }
// Integer argmax whose reducer differs from the other argmax tests in this
// file:
//   - it compares values with GE instead of GT;
//   - it has no NaN check;
//   - it resolves value ties with an explicit mhlo.minimum over the indices.
// Indices come from a length-9 iota reshaped to 1x9. The init values are
// INT32_MIN (-2147483648) for the value and 0 for the index. Only the index
// result (%4#1) is returned.
func.func @convert_pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> {
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<-2147483648> : tensor<i32>
  %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32>
  %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32>
  %4:2 = mhlo.reduce(%arg0 init: %1), (%3 init: %0) across dimensions = [1] : (tensor<1x9xi32>, tensor<1x9xi32>, tensor<i32>, tensor<i32>) -> (tensor<1xi32>, tensor<1xi32>)
    reducer(%arg1: tensor<i32>, %arg3: tensor<i32>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
    // Keep the larger (or equal) value; %arg1/%arg3 are the two values.
    %6 = mhlo.compare  GE, %arg1, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %7 = mhlo.select %6, %arg1, %arg3 : tensor<i1>, tensor<i32>
    // On an exact value tie take the smaller index, otherwise the index that
    // belongs to the kept value.
    %8 = mhlo.compare  EQ, %arg1, %arg3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %9 = mhlo.minimum %arg2, %arg4 : tensor<i32>
    %10 = mhlo.select %6, %arg2, %arg4 : tensor<i1>, tensor<i32>
    %11 = mhlo.select %8, %9, %10 : tensor<i1>, tensor<i32>
    mhlo.return %7, %11 : tensor<i32>, tensor<i32>
  }
  func.return %4#1 : tensor<1xi32>
}

// CHECK-LABEL:   func @convert_argmax(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
// CHECK:           %[[VAL_9:.*]] = arith.constant dense<2> : tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<4x32xf32>
// CHECK:           %[[VAL_11:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<4x32xi32>
// CHECK:           return %[[VAL_10]], %[[VAL_11]]
// CHECK:         }
// Argmax over dimension 2, written as a two-operand mhlo.reduce.
//   - Values come from %arg0.
//   - Indices come from a length-256 iota broadcast along dimension 2.
//   - Init values are -inf (0xFF800000) for the value and 0 for the index.
// Expected to become tf.Max + tf.ArgMax over axis 2.
func.func @convert_argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32>
  %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32>
  %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>):
    // Keep %arg1 when it is greater than %arg3, or when it is NaN (x != x).
    %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %8 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %9 = mhlo.or %7, %8 : tensor<i1>
    %10 = "mhlo.select"(%9, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Keep %arg2 when its value was kept, or on a value tie when %arg2 is the
    // smaller index (lowest index wins ties).
    %11 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %12 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %13 = mhlo.and %11, %12 : tensor<i1>
    %14 = mhlo.or %9, %13 : tensor<i1>
    %15 = "mhlo.select"(%14, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%10, %15) : (tensor<f32>, tensor<i32>) -> ()
  }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x32x256xf32>, tensor<4x32x256xi32>, tensor<f32>, tensor<i32>) -> (tensor<4x32xf32>, tensor<4x32xi32>)
  func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32>
}

// CHECK-LABEL: func @convert_argmax_constant(
// CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<2> : tensor<1xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_4]]) : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xi32>
// CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
// CHECK:         }
// Argmax over dimension 2 where the index operand %3 is a constant rather
// than an iota op. The constant holds 0..3 along the last dimension, which
// is the same data an iota along dimension 2 would produce. The reducer is
// identical to @convert_argmax. The expected output keeps the constants and
// produces tf.Max + tf.ArgMax over axis 2.
func.func @convert_argmax_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %3 = mhlo.constant dense<[[[0, 1, 2, 3], [0, 1, 2, 3]], [[0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>
  %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>):
    // Value: keep %arg1 if greater than %arg3 or NaN.
    %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %8 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %9 = mhlo.or %7, %8 : tensor<i1>
    %10 = "mhlo.select"(%9, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %11 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %12 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %13 = mhlo.and %11, %12 : tensor<i1>
    %14 = mhlo.or %9, %13 : tensor<i1>
    %15 = "mhlo.select"(%14, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%10, %15) : (tensor<f32>, tensor<i32>) -> ()
  }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<2x2x4xf32>, tensor<2x2x4xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
  func.return %4#0, %4#1 : tensor<2x2xf32>, tensor<2x2xi32>
}

// CHECK-LABEL:   func @convert_argmax_constant_non_z_axis(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xi32>) {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>}> : () -> tensor<4x4xi32>
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<4xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_4]]) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<4xi32>
// CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<4xf32>, tensor<4xi32>
// CHECK:         }
// Argmax over dimension 0 (not the innermost axis) with a constant index
// operand. %3 holds the row number at every position, which is the same data
// an iota along dimension 0 would produce. The reducer is identical to
// @convert_argmax. Expected to become tf.Max + tf.ArgMax over axis 0.
func.func @convert_argmax_constant_non_z_axis(%arg0: tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xi32>) {
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %3 = mhlo.constant dense<[[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]> : tensor<4x4xi32>
  %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>):
    // Value: keep %arg1 if greater than %arg3 or NaN.
    %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %8 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %9 = mhlo.or %7, %8 : tensor<i1>
    %10 = "mhlo.select"(%9, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %11 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %12 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %13 = mhlo.and %11, %12 : tensor<i1>
    %14 = mhlo.or %9, %13 : tensor<i1>
    %15 = "mhlo.select"(%14, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%10, %15) : (tensor<f32>, tensor<i32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4x4xi32>, tensor<f32>, tensor<i32>) -> (tensor<4xf32>, tensor<4xi32>)
  func.return %4#0, %4#1 : tensor<4xf32>, tensor<4xi32>
}

// CHECK-LABEL:   func.func @convert_argmax_bool(
// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<2xi1>) -> tensor<i32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_4:.*]] = "tf.Range"(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           %[[VAL_8:.*]] = "tf.Any"(%[[VAL_0]], %[[VAL_7]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor<i1>
// CHECK:           %[[VAL_9:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_7]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor<i32>
// CHECK:           return %[[VAL_9]] : tensor<i32>
// CHECK:         }

// Argmax over an i1 vector.
//   - Values are compared with GT; there is no NaN check for booleans.
//   - Indices come from a length-2 iota.
//   - Init values are false and 0.
// Only the index result is returned. Per the expected output, the value half
// becomes tf.Any and the index half becomes tf.ArgMax.
func.func @convert_argmax_bool(%arg0: tensor<2xi1>) -> tensor<i32> {
  %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32>
  %1 = mhlo.constant dense<false> : tensor<i1>
  %2 = mhlo.constant dense<0> : tensor<i32>
  %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor<i1>, tensor<i32>) -> (tensor<i1>, tensor<i32>)
    reducer(%arg1: tensor<i1>, %arg3: tensor<i1>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
    // Value: keep %arg1 when it is greater than %arg3.
    %4 = mhlo.compare  GT, %arg1, %arg3 : (tensor<i1>, tensor<i1>) -> tensor<i1>
    %5 = mhlo.select %4, %arg1, %arg3 : tensor<i1>, tensor<i1>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %6 = mhlo.compare  EQ, %arg1, %arg3 : (tensor<i1>, tensor<i1>) -> tensor<i1>
    %7 = mhlo.compare  LT, %arg2, %arg4 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %8 = mhlo.and %6, %7 : tensor<i1>
    %9 = mhlo.or %4, %8 : tensor<i1>
    %10 = mhlo.select %9, %arg2, %arg4 : tensor<i1>, tensor<i32>
    mhlo.return %5, %10 : tensor<i1>, tensor<i32>
  }
  return %3#1 : tensor<i32>
}

// CHECK-LABEL:   func @convert_argmin(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
// CHECK:           %[[VAL_9:.*]] = arith.constant dense<2> : tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<4x32xf32>
// CHECK:           %[[VAL_11:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<4x32xi32>
// CHECK:           return %[[VAL_10]], %[[VAL_11]]
// CHECK:         }
// Argmin counterpart of @convert_argmax over dimension 2.
//   - Values are compared with LT.
//   - The value init is +inf (0x7F800000); the index init is 0.
//   - NaN is kept via the x != x check.
//   - The lowest index wins value ties.
// Expected to become tf.Min + tf.ArgMin over axis 2.
func.func @convert_argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32xi32>) {
  %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32>
  %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32>
  %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>):
    // Value: keep %arg1 when it is less than %arg3, or when it is NaN.
    %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %8 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %9 = mhlo.or %7, %8 : tensor<i1>
    %10 = "mhlo.select"(%9, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %11 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %12 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %13 = mhlo.and %11, %12 : tensor<i1>
    %14 = mhlo.or %9, %13 : tensor<i1>
    %15 = "mhlo.select"(%14, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%10, %15) : (tensor<f32>, tensor<i32>) -> ()
  }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x32x256xf32>, tensor<4x32x256xi32>, tensor<f32>, tensor<i32>) -> (tensor<4x32xf32>, tensor<4x32xi32>)
  func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32>
}

// CHECK-LABEL:   func @convert_argmin_i16(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<2xi16>) -> (tensor<i16>, tensor<i32>) {
// CHECK:           %[[VAL_9:.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : {{.*}} -> tensor<i16>
// CHECK:           %[[VAL_11:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_9]]) : {{.*}} -> tensor<i32>
// CHECK:           return %[[VAL_10]], %[[VAL_11]]
// CHECK:         }
// Argmin over an i16 vector with i32 indices.
//   - The value init is 32767 (the largest i16); the index init is 0.
//   - Values are compared with LT; there is no NaN check for integers.
// Expected to become tf.Min + tf.ArgMin over axis 0.
// NOTE(review): %0 and the in-reducer %11 are never used. They may be
// deliberate (checking that the matcher tolerates dead constants); confirm
// before removing them.
func.func @convert_argmin_i16(%arg0: tensor<2xi16>) -> (tensor<i16>, tensor<i32>) {
  %0 = mhlo.constant dense<false> : tensor<i1>
  %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32>
  %2 = mhlo.constant dense<32767> : tensor<i16>
  %3 = mhlo.constant dense<0> : tensor<i32>
  %4:2 = "mhlo.reduce"(%arg0, %1, %2, %3) ({
  ^bb0(%arg1: tensor<i16>, %arg2: tensor<i32>, %arg3: tensor<i16>, %arg4: tensor<i32>):
    %11 = mhlo.constant dense<false> : tensor<i1>
    // Value: keep %arg1 when it is less than %arg3.
    %12 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i16>, tensor<i16>) -> tensor<i1>
    %13 = "mhlo.select"(%12, %arg1, %arg3) : (tensor<i1>, tensor<i16>, tensor<i16>) -> tensor<i16>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %14 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i16>, tensor<i16>) -> tensor<i1>
    %15 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %16 = mhlo.and %14, %15 : tensor<i1>
    %17 = mhlo.or %12, %16 : tensor<i1>
    %18 = "mhlo.select"(%17, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%13, %18) : (tensor<i16>, tensor<i32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi16>, tensor<2xi32>, tensor<i16>, tensor<i32>) -> (tensor<i16>, tensor<i32>)
  func.return %4#0, %4#1 : tensor<i16>, tensor<i32>
}


// CHECK-LABEL: func @convert_argmin_constant(
// CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0x7F800000> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<{{\[\[}}[0, 1, 2, 3], [0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>}> : () -> tensor<2x2x4xi32>
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<2> : tensor<1xi32>
// CHECK:           %[[VAL_5:.*]] = "tf.Min"(%[[VAL_0]], %[[VAL_4]]) <{keep_dims = false}> : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xf32>
// CHECK:           %[[VAL_6:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_4]]) : (tensor<2x2x4xf32>, tensor<1xi32>) -> tensor<2x2xi32>
// CHECK:           return %[[VAL_5]], %[[VAL_6]] : tensor<2x2xf32>, tensor<2x2xi32>
// CHECK:         }
// Argmin over dimension 2 where the index operand %3 is a constant rather
// than an iota op. The constant holds 0..3 along the last dimension. The
// reducer is identical to @convert_argmin. The expected output keeps the
// constants and produces tf.Min + tf.ArgMin over axis 2.
func.func @convert_argmin_constant(%arg0: tensor<2x2x4xf32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) {
  %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %3 = mhlo.constant dense<[[[0, 1, 2, 3], [0, 1, 2, 3]], [[0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x2x4xi32>
  %4:2 = "mhlo.reduce"(%arg0, %3, %0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>):
    // Value: keep %arg1 if less than %arg3 or NaN.
    %7 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %8 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %9 = mhlo.or %7, %8 : tensor<i1>
    %10 = "mhlo.select"(%9, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %11 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %12 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %13 = mhlo.and %11, %12 : tensor<i1>
    %14 = mhlo.or %9, %13 : tensor<i1>
    %15 = "mhlo.select"(%14, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%10, %15) : (tensor<f32>, tensor<i32>) -> ()
  }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<2x2x4xf32>, tensor<2x2x4xi32>, tensor<f32>, tensor<i32>) -> (tensor<2x2xf32>, tensor<2x2xi32>)
  func.return %4#0, %4#1 : tensor<2x2xf32>, tensor<2x2xi32>
}

// CHECK-LABEL:   func.func @convert_argmin_bool(
// CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<2xi1>) -> tensor<i32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_4:.*]] = "tf.Range"(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
// CHECK-DAG:       %[[VAL_6:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           %[[VAL_8:.*]] = "tf.All"(%[[VAL_0]], %[[VAL_7]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor<i1>
// CHECK:           %[[VAL_9:.*]] = "tf.ArgMin"(%[[VAL_0]], %[[VAL_7]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor<i32>
// CHECK:           return %[[VAL_9]] : tensor<i32>
// CHECK:         }
// Argmin over an i1 vector. It mirrors @convert_argmax_bool but compares
// values with LT. Only the index result is returned. Per the expected
// output, the value half becomes tf.All and the index half becomes
// tf.ArgMin.
// NOTE(review): the value init %1 is `false`, which is not the identity for
// a minimum over i1 (`true` would be). Assuming the init participates in the
// reduction as in the StableHLO reduce spec, this IR yields index 0 for
// input [true, false], whereas tf.ArgMin yields 1. Confirm whether the
// legalization should require a `true` init here.
func.func @convert_argmin_bool(%arg0: tensor<2xi1>) -> tensor<i32> {
  %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32>
  %1 = mhlo.constant dense<false> : tensor<i1>
  %2 = mhlo.constant dense<0> : tensor<i32>
  %3:2 = mhlo.reduce(%arg0 init: %1), (%0 init: %2) across dimensions = [0] : (tensor<2xi1>, tensor<2xi32>, tensor<i1>, tensor<i32>) -> (tensor<i1>, tensor<i32>)
    reducer(%arg1: tensor<i1>, %arg3: tensor<i1>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
    // Value: keep %arg1 when it is less than %arg3.
    %4 = mhlo.compare  LT, %arg1, %arg3 : (tensor<i1>, tensor<i1>) -> tensor<i1>
    %5 = mhlo.select %4, %arg1, %arg3 : tensor<i1>, tensor<i1>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %6 = mhlo.compare  EQ, %arg1, %arg3 : (tensor<i1>, tensor<i1>) -> tensor<i1>
    %7 = mhlo.compare  LT, %arg2, %arg4 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %8 = mhlo.and %6, %7 : tensor<i1>
    %9 = mhlo.or %4, %8 : tensor<i1>
    %10 = mhlo.select %9, %arg2, %arg4 : tensor<i1>, tensor<i32>
    mhlo.return %5, %10 : tensor<i1>, tensor<i32>
  }
  return %3#1 : tensor<i32>
}

// CHECK-LABEL:   func @convert_argmax_with_reshaped_iota(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0xFF800000> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_4:.*]] = "tf.Const"() <{value = dense<32> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_5:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_6:.*]] = "tf.Range"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32xi32>
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<[1, 32, 1]> : tensor<3xi64>
// CHECK:           %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_6]], %[[VAL_7]]) : (tensor<32xi32>, tensor<3xi64>) -> tensor<1x32x1xi32>
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant dense<1> : tensor<1xi32>
// CHECK:           %[[VAL_10:.*]] = "tf.Max"(%[[VAL_0]], %[[VAL_9]]) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32>
// CHECK:           %[[VAL_11:.*]] = "tf.ArgMax"(%[[VAL_0]], %[[VAL_9]]) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32>
// CHECK:           return %[[VAL_10]], %[[VAL_11]] : tensor<1x1xf32>, tensor<1x1xi32>
// CHECK:         }
// Argmax over dimension 1 where the indices are a length-32 iota reshaped to
// 1x32x1, rather than broadcast. The reducer matches @convert_argmax: GT
// compare, NaN kept, lowest index wins ties. The expected output
// materializes the indices with tf.Range + tf.Reshape, then uses tf.Max +
// tf.ArgMax over axis 1.
func.func @convert_argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf32>, tensor<1x1xi32>) {
  %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
  %1 = mhlo.constant dense<0> : tensor<i32>
  %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32>
  %3 = "mhlo.reshape"(%2) : (tensor<32xi32>) -> tensor<1x32x1xi32>
  %4:2 = mhlo.reduce(%arg0 init: %0), (%3 init: %1) across dimensions = [1] : (tensor<1x32x1xf32>, tensor<1x32x1xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x1xf32>, tensor<1x1xi32>)
   reducer(%arg1: tensor<f32>, %arg3: tensor<f32>) (%arg2: tensor<i32>, %arg4: tensor<i32>)  {
    // Value: keep %arg1 if greater than %arg3 or NaN.
    %5 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %6 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %7 = mhlo.or %5, %6 : tensor<i1>
    %8 = "mhlo.select"(%7, %arg1, %arg3) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
    // Index: follow the kept value; on a value tie prefer the smaller index.
    %9 = "mhlo.compare"(%arg1, %arg3) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %10 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %11 = mhlo.and %9, %10 : tensor<i1>
    %12 = mhlo.or %7, %11 : tensor<i1>
    %13 = "mhlo.select"(%12, %arg2, %arg4) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%8, %13) : (tensor<f32>, tensor<i32>) -> ()
  }
  func.return %4#0, %4#1 : tensor<1x1xf32>, tensor<1x1xi32>
}

// CHECK-LABEL:   func @convert_not(
// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> {
// CHECK:           %[[VAL_1:.*]] = "tf.LogicalNot"(%[[VAL_0]]) : {{.*}} -> tensor<5x3x1xi1>
// CHECK:           return %[[VAL_1]] : tensor<5x3x1xi1>
// CHECK:         }
// Logical negation of an i1 tensor; expected to become tf.LogicalNot.
func.func @convert_not(%input: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> {
  %negated = "mhlo.not"(%input) : (tensor<5x3x1xi1>) -> tensor<5x3x1xi1>
  return %negated : tensor<5x3x1xi1>
}

// CHECK-LABEL:   func @convert_not_i8(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i8>}> : () -> tensor<i8>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi8>, tensor<i8>) -> tensor<7x9x11xi8>
// CHECK:           return %[[RES]] : tensor<7x9x11xi8>
// CHECK:         }
// Bitwise complement of an i8 tensor; expected to become tf.BitwiseXor with
// an all-ones (-1) constant.
func.func @convert_not_i8(%input: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xi8>) -> tensor<7x9x11xi8>
  return %flipped : tensor<7x9x11xi8>
}

// CHECK-LABEL:   func @convert_not_i16(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i16>}> : () -> tensor<i16>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi16>, tensor<i16>) -> tensor<7x9x11xi16>
// CHECK:           return %[[RES]] : tensor<7x9x11xi16>
// CHECK:         }
// Bitwise complement of an i16 tensor; expected to become tf.BitwiseXor with
// an all-ones (-1) constant.
func.func @convert_not_i16(%input: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xi16>) -> tensor<7x9x11xi16>
  return %flipped : tensor<7x9x11xi16>
}

// CHECK-LABEL:   func @convert_not_i32(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi32>, tensor<i32>) -> tensor<7x9x11xi32>
// CHECK:           return %[[RES]] : tensor<7x9x11xi32>
// CHECK:         }
// Bitwise complement of an i32 tensor; expected to become tf.BitwiseXor with
// an all-ones (-1) constant.
func.func @convert_not_i32(%input: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xi32>) -> tensor<7x9x11xi32>
  return %flipped : tensor<7x9x11xi32>
}

// CHECK-LABEL:   func @convert_not_i64(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xi64>) -> tensor<7x9x11xi64> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<-1> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi64>, tensor<i64>) -> tensor<7x9x11xi64>
// CHECK:           return %[[RES]] : tensor<7x9x11xi64>
// CHECK:         }
// Bitwise complement of an i64 tensor; expected to become tf.BitwiseXor with
// an all-ones (-1) constant.
func.func @convert_not_i64(%input: tensor<7x9x11xi64>) -> tensor<7x9x11xi64> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xi64>) -> tensor<7x9x11xi64>
  return %flipped : tensor<7x9x11xi64>
}

// CHECK-LABEL:   func @convert_not_ui8(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<255> : tensor<ui8>}> : () -> tensor<ui8>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui8>, tensor<ui8>) -> tensor<7x9x11xui8>
// CHECK:           return %[[RES]] : tensor<7x9x11xui8>
// CHECK:         }
// Bitwise complement of a ui8 tensor; expected to become tf.BitwiseXor with
// an all-ones (255) constant.
func.func @convert_not_ui8(%input: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xui8>) -> tensor<7x9x11xui8>
  return %flipped : tensor<7x9x11xui8>
}

// CHECK-LABEL:   func @convert_not_ui16(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<65535> : tensor<ui16>}> : () -> tensor<ui16>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui16>, tensor<ui16>) -> tensor<7x9x11xui16>
// CHECK:           return %[[RES]] : tensor<7x9x11xui16>
// CHECK:         }
// Bitwise complement of a ui16 tensor; expected to become tf.BitwiseXor with
// an all-ones (65535) constant.
func.func @convert_not_ui16(%input: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xui16>) -> tensor<7x9x11xui16>
  return %flipped : tensor<7x9x11xui16>
}

// CHECK-LABEL:   func @convert_not_ui32(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<4294967295> : tensor<ui32>}> : () -> tensor<ui32>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui32>, tensor<ui32>) -> tensor<7x9x11xui32>
// CHECK:           return %[[RES]] : tensor<7x9x11xui32>
// CHECK:         }
// Bitwise complement of a ui32 tensor; expected to become tf.BitwiseXor with
// an all-ones (4294967295) constant.
func.func @convert_not_ui32(%input: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xui32>) -> tensor<7x9x11xui32>
  return %flipped : tensor<7x9x11xui32>
}

// CHECK-LABEL:   func @convert_not_ui64(
// CHECK-SAME:                      %[[ARG:.*]]: tensor<7x9x11xui64>) -> tensor<7x9x11xui64> {
// CHECK:           %[[CST:.*]] = "tf.Const"() <{value = dense<18446744073709551615> : tensor<ui64>}> : () -> tensor<ui64>
// CHECK:           %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui64>, tensor<ui64>) -> tensor<7x9x11xui64>
// CHECK:           return %[[RES]] : tensor<7x9x11xui64>
// CHECK:         }
// Bitwise complement of a ui64 tensor; expected to become tf.BitwiseXor with
// an all-ones (18446744073709551615) constant.
func.func @convert_not_ui64(%input: tensor<7x9x11xui64>) -> tensor<7x9x11xui64> {
  %flipped = "mhlo.not"(%input) : (tensor<7x9x11xui64>) -> tensor<7x9x11xui64>
  return %flipped : tensor<7x9x11xui64>
}

// -----

// A three-operand mhlo.while is expected to become tf.WhileRegion
// (is_stateless = false, parallel_iterations = 10) with the same region
// signatures. The LT compare in the condition maps to tf.Less, mhlo.add in the
// body maps to tf.AddV2, and both mhlo.return terminators become tf.Yield.
// CHECK-LABEL:  func @while_with_variadic() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<1> : tensor<i32>
// CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<0> : tensor<i32>
// CHECK-DAG:      %[[CST_2:.*]] = arith.constant dense<1000> : tensor<i32>
// CHECK:          %[[WHILEREGION_0:.*]]:3 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]]) <{is_stateless = false, parallel_iterations = 10 : i64}> ({
// CHECK:          ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
// CHECK:            %[[LESS_0:.*]] = "tf.Less"(%arg0, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK:            "tf.Yield"(%[[LESS_0]]) : (tensor<i1>) -> ()
// CHECK:          },  {
// CHECK:          ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
// CHECK:            %[[ADDV2_0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK:            "tf.Yield"(%[[ADDV2_0]], %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
// CHECK:          }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
// CHECK:          return %[[WHILEREGION_0]]#0, %[[WHILEREGION_0]]#1, %[[WHILEREGION_0]]#2 : tensor<i32>, tensor<i32>, tensor<i32>
// CHECK:        }
func.func @while_with_variadic() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
  %cst = arith.constant dense<1> : tensor<i32>
  %cst_0 = arith.constant dense<0> : tensor<i32>
  %cst_1 = arith.constant dense<1000> : tensor<i32>
  // Loop state: %cst_0 (0) is advanced by %cst (1) while it is less than %cst_1 (1000).
  %0:3 = "mhlo.while"(%cst_0, %cst, %cst_1) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
    %1 = "mhlo.compare"(%arg0, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  },  {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
    %1 = mhlo.add %arg0, %arg1 : tensor<i32>
    "mhlo.return"(%1, %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
  }) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
  func.return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
}

// -----

// A five-operand mhlo.while whose body also contains an add-reduction
// (mhlo.reduce over dimension 1 of a tensor<1x256xf32>). The reduction inside
// the loop body is expected to lower to tf.Sum (keep_dims = false) followed by
// tf.AddV2 with the loop-carried tensor<1xf32>. The zero init value is expected
// to be materialized as a tf.Const in the body. Only results #0-#2 and #4 are
// returned.
// CHECK-LABEL:  func @while_with_reduce(%arg0: tensor<1x256xf32>, %arg1: tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>) {
// CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<1> : tensor<i32>
// CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<0> : tensor<i32>
// CHECK-DAG:      %[[CST_2:.*]] = arith.constant dense<1000> : tensor<i32>
// CHECK:          %[[WHILEREGION_0:.*]]:5 = "tf.WhileRegion"(%[[CST_1]], %[[CST_0]], %[[CST_2]], %arg0, %arg1) <{is_stateless = false, parallel_iterations = 10 : i64}> ({
// CHECK:          ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
// CHECK:            %[[LESS_0:.*]] = "tf.Less"(%arg2, %arg4) : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK:            "tf.Yield"(%[[LESS_0]]) : (tensor<i1>) -> ()
// CHECK:          },  {
// CHECK:          ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
// CHECK:            %[[ADDV2_0:.*]] = "tf.AddV2"(%arg2, %arg3) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-DAG:        %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:        %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:            %[[SUM_0:.*]] = "tf.Sum"(%arg5, %[[CONST_1]]) <{keep_dims = false}> : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK:            %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUM_0]], %arg6) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK:            "tf.Yield"(%[[ADDV2_0]], %arg3, %arg4, %arg5, %[[ADDV2_1]]) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> ()
// CHECK:          }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>)
// CHECK:          return %[[WHILEREGION_0]]#0, %[[WHILEREGION_0]]#1, %[[WHILEREGION_0]]#2, %[[WHILEREGION_0]]#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>
// CHECK:        }
func.func @while_with_reduce(%arg0: tensor<1x256xf32>, %arg1: tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>) {
  %cst = arith.constant dense<1> : tensor<i32>
  %cst_0 = arith.constant dense<0> : tensor<i32>
  %cst_1 = arith.constant dense<1000> : tensor<i32>
  %0:5 = "mhlo.while"(%cst_0, %cst, %cst_1, %arg0 , %arg1) ({
  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
    %1 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  },  {
  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<1x256xf32>, %arg6: tensor<1xf32>):
    %1 = mhlo.add %arg2, %arg3 : tensor<i32>
    %2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    // Add-reduction of %arg5 over dimension 1 (the 256-wide axis).
    %3 = "mhlo.reduce"(%arg5, %2) ({
      ^bb0(%arg7: tensor<f32>, %arg8: tensor<f32>):
        %4 = mhlo.add %arg7, %arg8 : tensor<f32>
        "mhlo.return"(%4) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
    %5  = mhlo.add %3, %arg6 : tensor<1xf32>
    "mhlo.return"(%1, %arg3, %arg4, %arg5, %5) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> ()
  }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<1x256xf32>, tensor<1xf32>)
  func.return %0#0, %0#1, %0#2, %0#4: tensor<i32>, tensor<i32>, tensor<i32>, tensor<1xf32>
}

// -----

// mhlo.if whose branches return constants is expected to become tf.IfRegion
// (is_stateless = false). Each branch yields its constant through tf.Yield.
// CHECK-LABEL:  func @if
// CHECK-DAG:      %[[CST_0:.*]] = arith.constant dense<0> : tensor<i32>
// CHECK-DAG:      %[[CST_1:.*]] = arith.constant dense<1000> : tensor<i32>
// CHECK:          %[[RES:.*]]  = "tf.IfRegion"(%arg0) <{is_stateless = false}> ({
// CHECK:            "tf.Yield"(%[[CST_0]]) : (tensor<i32>) -> ()
// CHECK:          }, {
// CHECK:            "tf.Yield"(%[[CST_1]]) : (tensor<i32>) -> ()
// CHECK:          }) : (tensor<i1>) -> tensor<i32>
// CHECK:          return %[[RES]]
func.func @if(%arg0: tensor<i1>) -> (tensor<i32>) {
  %cst_0 = arith.constant dense<0> : tensor<i32>
  %cst_1 = arith.constant dense<1000> : tensor<i32>
  // First region is the true branch (yields 0), second is the false branch (yields 1000).
  %1 = "mhlo.if"(%arg0) ({
    "mhlo.return"(%cst_0) : (tensor<i32>) -> ()
  }, {
    "mhlo.return"(%cst_1) : (tensor<i32>) -> ()
  }) : (tensor<i1>) -> tensor<i32>
  func.return %1: tensor<i32>
}

// mhlo.dynamic_update_slice with scalar start indices is expected to lower to a
// tf.Pack (axis 0) of the indices followed by tf.XlaDynamicUpdateSlice.
// CHECK-LABEL:   func @convert_dynamic_update_slice(
// CHECK-SAME:                                       %[[VAL_0:[a-z0-9]*]]: tensor<28x1x100xf32>,
// CHECK-SAME:                                       %[[VAL_1:[a-z0-9]*]]: tensor<1x1x100xf32>,
// CHECK-SAME:                                       %[[VAL_2:[a-z0-9]*]]: tensor<i32>,
// CHECK-SAME:                                       %[[VAL_3:[a-z0-9]*]]: tensor<i32>,
// CHECK-SAME:                                       %[[VAL_4:[a-z0-9]*]]: tensor<i32>) -> tensor<28x1x100xf32> {
// CHECK:         %[[INDICES:.*]] = "tf.Pack"(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK:         %[[RESULT:.*]] = "tf.XlaDynamicUpdateSlice"(%[[VAL_0]], %[[VAL_1]], %[[INDICES]]) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<3xi32>) -> tensor<28x1x100xf32>
// CHECK:         return %[[RESULT]] : tensor<28x1x100xf32>
// CHECK:         }
func.func @convert_dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> tensor<28x1x100xf32> {
  %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<28x1x100xf32>
  func.return %0 : tensor<28x1x100xf32>
}

// Dynamic leading dimensions on both the operand and the update are expected to
// survive the tf.Pack + tf.XlaDynamicUpdateSlice lowering unchanged.
// CHECK-LABEL:   func @dynamic_update_slice_inputs_have_dynamic_dim(
// CHECK-SAME:                                       %[[OPERAND:[a-z0-9]*]]: tensor<?x4xi32>,
// CHECK-SAME:                                       %[[UPDATE:[a-z0-9]*]]: tensor<?x2xi32>,
// CHECK-SAME:                                       %[[IDX_0:[a-z0-9]*]]: tensor<i32>,
// CHECK-SAME:                                       %[[IDX_1:[a-z0-9]*]]: tensor<i32>) -> tensor<?x4xi32> {
// CHECK:         %[[INDICES:.*]] = "tf.Pack"(%[[IDX_0]], %[[IDX_1]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
// CHECK:         %[[RESULT:.*]] = "tf.XlaDynamicUpdateSlice"(%[[OPERAND]], %[[UPDATE]], %[[INDICES]]) : (tensor<?x4xi32>, tensor<?x2xi32>, tensor<2xi32>) -> tensor<?x4xi32>
// CHECK:         return %[[RESULT]] : tensor<?x4xi32>
// CHECK:         }
func.func @dynamic_update_slice_inputs_have_dynamic_dim(%arg0: tensor<?x4xi32>, %arg1: tensor<?x2xi32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<?x4xi32> {
  %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3 : (tensor<?x4xi32>, tensor<?x2xi32>, tensor<i32>, tensor<i32>) -> tensor<?x4xi32>
  func.return %0 : tensor<?x4xi32>
}

// A dynamic middle dimension on the operand only (static update) is expected to
// lower to tf.Pack + tf.XlaDynamicUpdateSlice with the dynamic result type kept.
// CHECK-LABEL:   func @dynamic_update_slice_operand_has_dynamic_dim(
// CHECK-SAME:                                       %[[OPERAND:[a-z0-9]*]]: tensor<1x?x256xf32>,
// CHECK-SAME:                                       %[[UPDATE:[a-z0-9]*]]: tensor<1x1x256xf32>,
// CHECK-SAME:                                       %[[IDX_0:[a-z0-9]*]]: tensor<i32>,
// CHECK-SAME:                                       %[[IDX_1:[a-z0-9]*]]: tensor<i32>,
// CHECK-SAME:                                       %[[IDX_2:[a-z0-9]*]]: tensor<i32>) -> tensor<1x?x256xf32> {
// CHECK:         %[[INDICES:.*]] = "tf.Pack"(%[[IDX_0]], %[[IDX_1]], %[[IDX_2]]) <{axis = 0 : i64}> : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK:         %[[RESULT:.*]] = "tf.XlaDynamicUpdateSlice"(%[[OPERAND]], %[[UPDATE]], %[[INDICES]]) : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor<3xi32>) -> tensor<1x?x256xf32>
// CHECK:         return %[[RESULT]] : tensor<1x?x256xf32>
// CHECK:         }
func.func @dynamic_update_slice_operand_has_dynamic_dim(%arg0: tensor<1x?x256xf32>, %arg1: tensor<1x1x256xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> tensor<1x?x256xf32> {
  %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3, %arg4 : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x?x256xf32>
  func.return %0 : tensor<1x?x256xf32>
}

// An mhlo.reduce over i1 with an mhlo.and body and a constant `true` init is
// expected to lower to tf.All over dimensions [0, 2] (keep_dims = false). The
// returned value must be the tf.All result itself.
// CHECK-LABEL:   func @convert_reduce_to_all(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x2x3x4x5xi1>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
// CHECK-DAG:       %[[TRUE_CST:.*]] = "tf.Const"() <{value = dense<true> : tensor<i1>}> : () -> tensor<i1>
// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_0]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
// CHECK:           return %[[VAL_0]] : tensor<2x4x5xi1>
// CHECK:         }
func.func @convert_reduce_to_all(%arg0: tensor<1x2x3x4x5xi1>, %arg1: tensor<2xi64>) -> tensor<2x4x5xi1> {
  %0 = mhlo.constant dense<true> : tensor<i1>
  %1 = "mhlo.reduce"(%arg0, %0) ({
    ^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
        %2 = mhlo.and %arg2, %arg3 : tensor<i1>
        "mhlo.return"(%2) : (tensor<i1>) -> ()
    }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<1x2x3x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
  func.return %1: tensor<2x4x5xi1>
}

// With a non-constant init value, the and-reduction is expected to become
// tf.All over dimensions [0, 2] followed by tf.LogicalAnd with the init value.
// The returned value must be the tf.LogicalAnd result.
// CHECK-LABEL:   func @convert_reduce_to_all_non_constant_init(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<i1>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<1x2x3x4x5xi1>,
// CHECK-SAME:                                %[[ARG_2:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_0:.*]] = "tf.All"(%[[ARG_1]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
// CHECK:           %[[VAL_1:.*]] = "tf.LogicalAnd"(%[[VAL_0]], %[[ARG_0]]) : (tensor<2x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
// CHECK:           return %[[VAL_1]] : tensor<2x4x5xi1>
// CHECK:         }
func.func @convert_reduce_to_all_non_constant_init(%arg0: tensor<i1>, %arg1: tensor<1x2x3x4x5xi1>, %arg2: tensor<2xi64>) -> tensor<2x4x5xi1> {
  %0 = "mhlo.reduce"(%arg1, %arg0) ({
    ^bb0(%arg3: tensor<i1>, %arg4: tensor<i1>):
        %1 = mhlo.and %arg3, %arg4 : tensor<i1>
        "mhlo.return"(%1) : (tensor<i1>) -> ()
    }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<1x2x3x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
  func.return %0: tensor<2x4x5xi1>
}

// An mhlo.reduce over i1 with an mhlo.or body and a constant `false` init is
// expected to lower to tf.Any over dimensions [0, 2] (keep_dims = false). The
// returned value must be the tf.Any result itself.
// CHECK-LABEL:   func @convert_reduce_to_any(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<1x2x3x4x5xi1>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
// CHECK-DAG:       %[[FALSE_CST:.*]] = "tf.Const"() <{value = dense<false> : tensor<i1>}> : () -> tensor<i1>
// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_0]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
// CHECK:           return %[[VAL_0]] : tensor<2x4x5xi1>
// CHECK:         }
func.func @convert_reduce_to_any(%arg0: tensor<1x2x3x4x5xi1>, %arg1: tensor<2xi64>) -> tensor<2x4x5xi1> {
  %0 = mhlo.constant dense<false> : tensor<i1>
  %1 = "mhlo.reduce"(%arg0, %0) ({
    ^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
        %2 = mhlo.or %arg2, %arg3 : tensor<i1>
        "mhlo.return"(%2) : (tensor<i1>) -> ()
    }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<1x2x3x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
  func.return %1: tensor<2x4x5xi1>
}

// With a non-constant init value, the or-reduction is expected to become
// tf.Any over dimensions [0, 2] followed by tf.LogicalOr with the init value.
// The returned value must be the tf.LogicalOr result.
// CHECK-LABEL:   func @convert_reduce_to_any_non_constant_init(
// CHECK-SAME:                                %[[ARG_0:.*]]: tensor<i1>,
// CHECK-SAME:                                %[[ARG_1:.*]]: tensor<1x2x3x4x5xi1>,
// CHECK-SAME:                                %[[ARG_2:.*]]: tensor<2xi64>) -> tensor<2x4x5xi1> {
// CHECK-DAG:       %[[DIMENSIONS:.*]] = "tf.Const"() <{value = dense<[0, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
// CHECK:           %[[VAL_0:.*]] = "tf.Any"(%[[ARG_1]], %[[DIMENSIONS]]) <{keep_dims = false}> : (tensor<1x2x3x4x5xi1>, tensor<2xi64>) -> tensor<2x4x5xi1>
// CHECK:           %[[VAL_1:.*]] = "tf.LogicalOr"(%[[VAL_0]], %[[ARG_0]]) : (tensor<2x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
// CHECK:           return %[[VAL_1]] : tensor<2x4x5xi1>
// CHECK:         }
func.func @convert_reduce_to_any_non_constant_init(%arg0: tensor<i1>, %arg1: tensor<1x2x3x4x5xi1>, %arg2: tensor<2xi64>) -> tensor<2x4x5xi1> {
  %0 = "mhlo.reduce"(%arg1, %arg0) ({
    ^bb0(%arg3: tensor<i1>, %arg4: tensor<i1>):
        %1 = mhlo.or %arg3, %arg4 : tensor<i1>
        "mhlo.return"(%1) : (tensor<i1>) -> ()
    }) {dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<1x2x3x4x5xi1>, tensor<i1>) -> tensor<2x4x5xi1>
  func.return %0: tensor<2x4x5xi1>
}

// mhlo.sort along dimension 1 with a GT comparator, whose second operand is an
// iota broadcast along that dimension, is expected to become tf.TopKV2 with
// k = 6 (sorted = true). The iota and broadcast producers are expected to
// appear lowered to tf.Range(0, 6, 1) and tf.BroadcastTo.
// CHECK-LABEL:   func @convert_sort_to_topk_iota_broadcast(
// CHECK-SAME:                                              %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VAL_3:.*]] = "tf.Range"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<6xi32>
// CHECK:           %[[VAL_4:.*]] = arith.constant dense<[3, 6]> : tensor<2xi64>
// CHECK:           %[[VAL_5:.*]] = "tf.BroadcastTo"(%[[VAL_3]], %[[VAL_4]]) : (tensor<6xi32>, tensor<2xi64>) -> tensor<3x6xi32>
// CHECK:           %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
// CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
// CHECK:         }
func.func @convert_sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
  %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<6xi32>
  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
  %2:2 = "mhlo.sort"(%arg0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%3) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
  func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
}

// Like @convert_sort_to_topk_iota_broadcast, but the iota is given as a literal
// constant [0..5]. The constant is expected to lower to tf.Const feeding
// tf.BroadcastTo, and the sort to tf.TopKV2 with k = 6.
// CHECK-LABEL:   func @convert_sort_to_topk_iotacst_broadcast(
// CHECK-SAME:                                                 %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>}> : () -> tensor<6xi32>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<[3, 6]> : tensor<2xi64>
// CHECK:           %[[VAL_2:.*]] = "tf.BroadcastTo"(%[[VAL_0]], %[[VAL_1]]) : (tensor<6xi32>, tensor<2xi64>) -> tensor<3x6xi32>
// CHECK:           %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
// CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
// CHECK:         }
func.func @convert_sort_to_topk_iotacst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
  %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>
  %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32>
  %2:2 = "mhlo.sort"(%arg0, %1) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%3) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
  func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32>
}

// mhlo.sort along dimension 1 with a GT comparator, paired with a constant
// index tensor whose every row is [0, 1, 2, 3, 4, 5], is expected to be
// recognized as tf.TopKV2 with k = 6 (sorted = true).
// CHECK-LABEL:   func @convert_sort_to_topk_const(
// CHECK-SAME:                                     %[[ARG_0:.*]]: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
// CHECK-DAG:       %[[VAL_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<3x6xi32>}> : () -> tensor<3x6xi32>
// CHECK-DAG:       %[[K:.*]] = "tf.Const"() <{value = dense<6> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           %[[VALUES:.*]], %[[INDICES:.*]] = "tf.TopKV2"(%[[ARG_0]], %[[K]]) <{sorted = true}> : (tensor<3x6xf32>, tensor<i32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
// CHECK:           return %[[VALUES]], %[[INDICES]] : tensor<3x6xf32>, tensor<3x6xi32>
// CHECK:         }
func.func @convert_sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
  %0 = mhlo.constant dense<[[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]> : tensor<3x6xi32>
  %1:2 = "mhlo.sort"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
    %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%3) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
  func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32>
}

// Negative case: the first row of the index constant is [0, 1, 2, 3, 4, 4],
// which is not an iota, so the sort must not be rewritten into tf.TopKV2.
// CHECK-LABEL:   func @not_convert_sort_to_topk
// CHECK-NOT:       "tf.TopKV2"
func.func @not_convert_sort_to_topk(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) {
  %0 = mhlo.constant dense<[[0, 1, 2, 3, 4, 4], [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]> : tensor<3x6xi32>
  %1:2 = "mhlo.sort"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
    %4 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo<comparison_type TOTALORDER>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
    "mhlo.return"(%4) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>)
  func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32>
}

// Signed i32 mhlo.remainder is expected to lower directly to tf.Mod.
// CHECK-LABEL:   func @convert_remainder_for_int32(
// CHECK-SAME:                                      %[[LHS:.*]]: tensor<10x8xi32>,
// CHECK-SAME:                                      %[[RHS:.*]]: tensor<10x8xi32>) -> tensor<10x8xi32> {
// CHECK:           %[[MOD:.*]] = "tf.Mod"(%[[LHS]], %[[RHS]]) : (tensor<10x8xi32>, tensor<10x8xi32>) -> tensor<10x8xi32>
// CHECK:           return %[[MOD]] : tensor<10x8xi32>
// CHECK:         }
func.func @convert_remainder_for_int32(%lhs: tensor<10x8xi32>, %rhs: tensor<10x8xi32>) -> tensor<10x8xi32> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xi32>
  func.return %rem : tensor<10x8xi32>
}

// Signed i64 mhlo.remainder is expected to lower directly to tf.Mod.
// CHECK-LABEL:   func @convert_remainder_for_int64(
// CHECK-SAME:                                      %[[LHS:.*]]: tensor<10x8xi64>,
// CHECK-SAME:                                      %[[RHS:.*]]: tensor<10x8xi64>) -> tensor<10x8xi64> {
// CHECK:           %[[MOD:.*]] = "tf.Mod"(%[[LHS]], %[[RHS]]) : (tensor<10x8xi64>, tensor<10x8xi64>) -> tensor<10x8xi64>
// CHECK:           return %[[MOD]] : tensor<10x8xi64>
// CHECK:         }
func.func @convert_remainder_for_int64(%lhs: tensor<10x8xi64>, %rhs: tensor<10x8xi64>) -> tensor<10x8xi64> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xi64>
  func.return %rem : tensor<10x8xi64>
}

// Negative case: mhlo.remainder on i16 is expected not to be rewritten to tf.Mod.
// CHECK-LABEL:   func @not_convert_remainder_for_int16(
// CHECK-NOT:       "tf.Mod"
func.func @not_convert_remainder_for_int16(%lhs: tensor<10x8xi16>, %rhs: tensor<10x8xi16>) -> tensor<10x8xi16> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xi16>
  func.return %rem : tensor<10x8xi16>
}

// Negative case: mhlo.remainder on ui16 is expected not to be rewritten to tf.Mod.
// CHECK-LABEL:   func @not_convert_remainder_for_uint16(
// CHECK-NOT:       "tf.Mod"
func.func @not_convert_remainder_for_uint16(%lhs: tensor<10x8xui16>, %rhs: tensor<10x8xui16>) -> tensor<10x8xui16> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xui16>
  func.return %rem : tensor<10x8xui16>
}

// Negative case: mhlo.remainder on ui32 is expected not to be rewritten to tf.Mod.
// CHECK-LABEL:   func @not_convert_remainder_for_uint32(
// CHECK-NOT:       "tf.Mod"
func.func @not_convert_remainder_for_uint32(%lhs: tensor<10x8xui32>, %rhs: tensor<10x8xui32>) -> tensor<10x8xui32> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xui32>
  func.return %rem : tensor<10x8xui32>
}

// Negative case: mhlo.remainder on ui64 is expected not to be rewritten to tf.Mod.
// CHECK-LABEL:   func @not_convert_remainder_for_uint64(
// CHECK-NOT:       "tf.Mod"
func.func @not_convert_remainder_for_uint64(%lhs: tensor<10x8xui64>, %rhs: tensor<10x8xui64>) -> tensor<10x8xui64> {
  %rem = mhlo.remainder %lhs, %rhs : tensor<10x8xui64>
  func.return %rem : tensor<10x8xui64>
}

// mhlo.popcnt on i32 is expected to lower to tf.PopulationCount (which yields
// ui8), followed by a tf.Cast back to the original i32 element type.
// CHECK-LABEL:   func @convert_population_count_i32(
// CHECK-SAME:                                   %[[INPUT:.*]]: tensor<8xi32>
// CHECK:       %[[COUNT:.*]] = "tf.PopulationCount"(%[[INPUT]]) : (tensor<8xi32>) -> tensor<8xui8>
// CHECK:       %[[WIDENED:.*]] = "tf.Cast"(%[[COUNT]]) <{Truncate = false}> : (tensor<8xui8>) -> tensor<8xi32>
// CHECK:       return %[[WIDENED]]
// CHECK:         }
func.func @convert_population_count_i32(%input: tensor<8xi32>) -> tensor<8xi32> {
  %count = "mhlo.popcnt"(%input) : (tensor<8xi32>) -> tensor<8xi32>
  func.return %count : tensor<8xi32>
}

// When the input is already ui8, tf.PopulationCount's result type matches and
// no tf.Cast is expected.
// CHECK-LABEL:   func @convert_population_count_ui8(
// CHECK-SAME:                                   %[[INPUT:.*]]: tensor<8xui8>
// CHECK:       %[[COUNT:.*]] = "tf.PopulationCount"(%[[INPUT]]) : (tensor<8xui8>) -> tensor<8xui8>
// CHECK:       return %[[COUNT]]
// CHECK:         }
func.func @convert_population_count_ui8(%input: tensor<8xui8>) -> tensor<8xui8> {
  %count = "mhlo.popcnt"(%input) : (tensor<8xui8>) -> tensor<8xui8>
  func.return %count : tensor<8xui8>
}

// mhlo.torch_index_select with dim = 0 and batch_dims = 0 is expected to lower
// to tf.GatherV2 with a scalar i64 axis constant of 0 and batch_dims = 0.
// CHECK-LABEL:   func @torch_index_select(
// CHECK:       %[[AXIS:.+]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK:       %[[RES:.+]] = "tf.GatherV2"(%arg0, %arg1, %[[AXIS]]) <{batch_dims = 0 : i64}>
// CHECK:       return %[[RES]]

func.func @torch_index_select(%arg0: tensor<2x1xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> {
  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
    batch_dims = 0 : i64, dim = 0 : i64
  } : (tensor<2x1xf32>, tensor<2xi32>) -> tensor<2x1xf32>
  func.return %0 : tensor<2x1xf32>
}

// An add reduce_window over tensor<4x12xf32> is the lowered form of a cumulative
// sum along axis 0. Its window is [4, 1] (the full length of dimension 0), with
// leading padding [3, 0] on that dimension and unit strides/dilations. It is
// expected to become tf.Cumsum with axis 0 (exclusive = false, reverse = false).
// CHECK-LABEL:   func @lowered_cumsum(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
// CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
// CHECK:         }
func.func @lowered_cumsum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.add %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
  func.return %1 : tensor<4x12xf32>
}

// Same reduction as @lowered_cumsum, but base_dilations, window_dilations and
// window_strides are omitted. The expected output is the same tf.Cumsum as when
// they are spelled out as all-ones.
// CHECK-LABEL:   func @lowered_cumsum_trivial_attrs(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[VAL_3:.*]] = "tf.Cumsum"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
// CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
// CHECK:         }
func.func @lowered_cumsum_trivial_attrs(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.add %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) {padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
  func.return %1 : tensor<4x12xf32>
}

// A multiply reduce_window with a unit init value, window [1, 12] (the full
// length of dimension 1) and leading padding [11, 0] on dimension 1 is expected
// to become tf.Cumprod along axis 1 (exclusive = false, reverse = false).
// CHECK-LABEL:   func @lowered_cumprod(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x12xf32>) -> tensor<4x12xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:       %[[VAL_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
// CHECK:           %[[VAL_3:.*]] = "tf.Cumprod"(%[[VAL_0]], %[[VAL_2]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor<i64>) -> tensor<4x12xf32>
// CHECK:           return %[[VAL_3]] : tensor<4x12xf32>
// CHECK:         }
func.func @lowered_cumprod(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.multiply %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 0], [11, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[1, 12]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
  func.return %1 : tensor<4x12xf32>
}

// Negative case: a multiply reduce_window with all-ones window dimensions and
// zero padding never combines neighbouring elements. It must not be recognized
// as a cumulative product.
// CHECK-LABEL:   func @reduce_window_trivial_window_dims(
// CHECK-NOT:       "tf.Cumprod"
func.func @reduce_window_trivial_window_dims(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> {
  %0 = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %1 = "mhlo.reduce_window"(%arg0, %0) ({
  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
    %2 = mhlo.multiply %arg1, %arg2 : tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) {padding = dense<0> : tensor<2x2xi64>, window_dimensions = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor<f32>) -> tensor<4x12xf32>
  func.return %1 : tensor<4x12xf32>
}

// An mhlo.constant with a quantized (!quant.uniform) result type is expected to
// be left alone: no tf.Const may appear in the output for this function.
// CHECK-LABEL:   func @const_quant
// CHECK-NOT:       "tf.Const"
func.func @const_quant() -> tensor<512x1x!quant.uniform<i8:f32, 0.013133913278579712>> {
  %0 = mhlo.constant() {value = dense<0> : tensor<512x1xi8>} : () -> tensor<512x1x!quant.uniform<i8:f32, 0.013133913278579712>>
  func.return %0 : tensor<512x1x!quant.uniform<i8:f32, 0.013133913278579712>>
}

// Negative case: mhlo.dot whose rhs has a quantized (!quant.uniform) element
// type is expected not to be lowered to tf.BatchMatMulV3.
// CHECK-LABEL:   func @convert_dot_quant_type(
// CHECK-NOT:       "tf.BatchMatMulV3"
func.func @convert_dot_quant_type(%arg0: tensor<1x256xf32>, %arg1: tensor<256x!quant.uniform<i8:f32, 1.0>>) -> tensor<1xf32> {
  %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x256xf32>, tensor<256x!quant.uniform<i8:f32, 1.0>>) -> tensor<1xf32>
  func.return %0 : tensor<1xf32>
}

// Querying a static dimension (dimension 1 has size 256) is expected to produce
// a scalar tf.Const holding 256, which is returned directly.
// CHECK-LABEL: func @get_dimension_size(
// CHECK-SAME:              %[[ARG_0:.*]]: tensor<4x256x?xf32>) -> tensor<i32> {
// CHECK:         %[[CST_0:.*]] = "tf.Const"() <{value = dense<256> : tensor<i32>}> : () -> tensor<i32>
// CHECK:         return %[[CST_0]] : tensor<i32>
func.func @get_dimension_size(%arg0: tensor<4x256x?xf32>) -> tensor<i32> {
  %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor<4x256x?xf32>) -> tensor<i32>
  func.return %0 : tensor<i32>
}

// Querying a dynamic dimension (dimension 2 is '?') is expected to read it at
// runtime. The lowering takes tf.Shape of the operand, slices one element
// starting at index 2, and squeezes axis 0 down to a scalar. The producers are
// matched order-independently because only their data flow matters.
// CHECK-LABEL: func @get_dimension_size_dynamic(
// CHECK-SAME:              %[[ARG_0:.*]]: tensor<4x256x?xf32>) -> tensor<i32> {
// CHECK-DAG:     %[[VAL_0:.*]] = "tf.Shape"(%[[ARG_0]]) : (tensor<4x256x?xf32>) -> tensor<3xi32>
// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK:         %[[VAL_1:.*]] = "tf.Slice"(%[[VAL_0]], %[[CST_1]], %[[CST_0]]) : (tensor<3xi32>, tensor<1xi64>, tensor<1xi32>) -> tensor<1xi32>
// CHECK:         %[[VAL_2:.*]] = "tf.Squeeze"(%[[VAL_1]]) <{squeeze_dims = [0]}> : (tensor<1xi32>) -> tensor<i32>
// CHECK:         return %[[VAL_2]] : tensor<i32>
func.func @get_dimension_size_dynamic(%arg0: tensor<4x256x?xf32>) -> tensor<i32> {
  %0 = "mhlo.get_dimension_size"(%arg0) <{dimension = 2 : i64}> : (tensor<4x256x?xf32>) -> tensor<i32>
  func.return %0 : tensor<i32>
}

// A 1-D i32 mhlo.dynamic_iota is expected to reshape its one-element shape
// operand to a scalar limit and emit tf.Range(0, limit, 1).
// CHECK-LABEL: func @dynamic_iota_i32_1d(
// CHECK-SAME:                  %[[ARG_0:.*]]: tensor<1xi32>) -> tensor<?xi32> {
// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
// CHECK:         %[[VAL_0:.*]] = "tf.Reshape"(%[[ARG_0]], %[[CST_0]]) : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
// CHECK:         %[[VAL_1:.*]] = "tf.Range"(%[[CST_1]], %[[VAL_0]], %[[CST_2]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK:         return %[[VAL_1]] : tensor<?xi32>
func.func @dynamic_iota_i32_1d(%arg0: tensor<1xi32>) -> tensor<?xi32> {
  %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xi32>) -> tensor<?xi32>
  func.return %0 : tensor<?xi32>
}

// For an f32 result, the i32 shape operand is expected to be tf.Cast to f32
// first. It is then reshaped to a scalar limit and fed to a floating-point
// tf.Range(0.0, limit, 1.0).
// CHECK-LABEL: func @dynamic_iota_f32_1d(
// CHECK-SAME:                  %[[ARG_0:.*]]: tensor<1xi32>) -> tensor<?xf32> {
// CHECK:         %[[VAL_0:.*]] = "tf.Cast"(%[[ARG_0]]) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xf32>
// CHECK-DAG:     %[[CST_0:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32>
// CHECK:         %[[VAL_1:.*]] = "tf.Reshape"(%[[VAL_0]], %[[CST_0]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
// CHECK-DAG:     %[[CST_1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG:     %[[CST_2:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK:         %[[VAL_2:.*]] = "tf.Range"(%[[CST_1]], %[[VAL_1]], %[[CST_2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
// CHECK:         return %[[VAL_2]] : tensor<?xf32>
func.func @dynamic_iota_f32_1d(%arg0: tensor<1xi32>) -> tensor<?xf32> {
  %0 = "mhlo.dynamic_iota"(%arg0) <{iota_dimension = 0 : i64}> : (tensor<1xi32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @real_dynamic_slice_strides_equal_to_1_signed(
// CHECK-SAME:              %[[ARG_0:.*]]: tensor<1x?x4x256xf32>,
// CHECK-SAME:              %[[ARG_1:.*]]: tensor<4xi32>,
// CHECK-SAME:              %[[ARG_2:.*]]: tensor<4xi32>) -> tensor<1x?x4x128xf32> {
// CHECK:         %[[CST:.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK:         %[[VAL_0:.*]] = "tf.StridedSlice"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CST]]) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x?x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x4x128xf32>
// CHECK:         return %[[VAL_0]] : tensor<1x?x4x128xf32>
// A real_dynamic_slice whose strides operand is a constant splat of 1 is expected
// to become a tf.StridedSlice with every mask cleared. The start/limit operands
// are forwarded as begin/end, and the strides are re-materialized as a tf.Const.
// SSA values are matched through captured variables, not printer-chosen names.
func.func @real_dynamic_slice_strides_equal_to_1_signed(%arg0: tensor<1x?x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>) -> tensor<1x?x4x128xf32> {
  %cst = mhlo.constant dense<1> : tensor<4xi32>
  %0 = mhlo.real_dynamic_slice %arg0, %arg1, %arg2, %cst : (tensor<1x?x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x4x128xf32>
  func.return %0 : tensor<1x?x4x128xf32>
}

// CHECK-LABEL: func @real_dynamic_slice_strides_not_equal_to_1(
// CHECK-SAME:              %arg0: tensor<1x?x2x4xf32>,
// CHECK-SAME:              %arg1: tensor<4xi32>,
// CHECK-SAME:              %arg2: tensor<4xi32>) -> tensor<1x?x1x2xf32> {
// NOTE(review): the three expected-output lines below have no ':' after their
// prefix, so FileCheck treats them as plain comments; only the signature above
// is verified. Confirm whether a stride-2 real_dynamic_slice is really meant to
// lower to tf.StridedSlice before enabling them.
// CHECK          %cst = "tf.Const"() <{value = dense<2> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK          %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %cst) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64}> : (tensor<1x?x2x4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x1x2xf32>
// CHECK          return %0 : tensor<1x?x1x2xf32>
// Input: a real_dynamic_slice whose strides operand is a constant splat of 2.
func.func @real_dynamic_slice_strides_not_equal_to_1(%arg0: tensor<1x?x2x4xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>) -> tensor<1x?x1x2xf32> {
  %strides = mhlo.constant dense<2> : tensor<4xi32>
  %slice = mhlo.real_dynamic_slice %arg0, %arg1, %arg2, %strides : (tensor<1x?x2x4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x?x1x2xf32>
  func.return %slice : tensor<1x?x1x2xf32>
}

// CHECK-LABEL:   func @remove_shape_assertion_custom_call
// CHECK-NOT:       mhlo.custom_call
// Verifies that a @shape_assertion custom_call, whose only operand is the i1
// predicate computed below, is dropped by the legalization.
// The negative pattern above is deliberately unquoted. mhlo.custom_call has a
// custom assembly form (used by the input below), and the printer emits that
// form without quotes. A quoted pattern would never match, so the check would
// pass even if the op survived.
func.func @remove_shape_assertion_custom_call(%arg0: tensor<?x5xi32>) -> tensor<i32> {
  %0 = mhlo.constant dense<3> : tensor<i32>
  %1 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor<?x5xi32>) -> tensor<i32>
  %ok = mhlo.compare  EQ, %1, %0,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
  mhlo.custom_call @shape_assertion(%ok) {
    error_message = "The error message",
    has_side_effect = true
  } : (tensor<i1>) -> ()
  func.return %1 : tensor<i32>
}

// CHECK-LABEL: func @convert_approx_top_k_custom_call(
// CHECK-SAME:                                        %[[ARG_0:.*]]: tensor<1x4xf32>,
// CHECK-SAME:                                        %[[ARG_1:.*]]: tensor<1x4xi32>,
// CHECK-SAME:                                        %[[ARG_2:.*]]: tensor<f32>,
// CHECK-SAME:                                        %[[ARG_3:.*]]: tensor<i32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) {
// CHECK:          %[[VALUES:.*]], %[[INDICES:.*]] = "tf.ApproxTopK"(%[[ARG_0]]) <{aggregate_to_topk = true, is_max_k = true, k = 4 : i64, recall_target = 8.500000e-01 : f32, reduction_dimension = 1 : i64, reduction_input_size_override = -1 : i64}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
// CHECK:          return %[[VALUES]], %[[INDICES]] : tensor<1x4xf32>, tensor<1x4xi32>
// CHECK:        }
// An @ApproxTopK custom_call is expected to collapse into a single
// tf.ApproxTopK that consumes only the first operand. The backend_config fields
// map onto the op's attributes (top_k -> k, reduction_dim ->
// reduction_dimension). is_max_k = true presumably comes from the GT
// comparator in called_computations; confirm against the pattern.
func.func @convert_approx_top_k_custom_call(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>, %arg2: tensor<f32>, %arg3: tensor<i32>) -> (tensor<1x4xf32>, tensor<1x4xi32>) {
  %topk:2 = mhlo.custom_call @ApproxTopK(%arg0, %arg1, %arg2, %arg3) {
    api_version = 4 : i32,
    called_computations = [@top_k_gt_f32_comparator],
    backend_config = {
      aggregate_to_topk = true,
      is_fallback = true,
      recall_target = 8.500000e-01 : f32,
      reduction_dim = 1 : i64,
      reduction_input_size_override = -1 : i64,
      top_k = 4 : i64
    }
  } : (tensor<1x4xf32>, tensor<1x4xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x4xf32>, tensor<1x4xi32>)
  func.return %topk#0, %topk#1 : tensor<1x4xf32>, tensor<1x4xi32>
}

// Comparator named in the called_computations of the @ApproxTopK custom_call
// test. Yields whether the first f32 operand is strictly greater than the
// second; the two i32 operands are accepted but not used.
func.func @top_k_gt_f32_comparator(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
  %is_greater = mhlo.compare  GT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
  func.return %is_greater : tensor<i1>
}
