load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.bzl", "if_google")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")

# Package-wide defaults. Any target below that does not set its own
# `visibility` is visible to the `:friends` package group declared in this
# file; the package is declared under the "notice" license type.
package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

# Visibility group used as this package's default visibility. It lists no
# packages of its own; membership comes entirely from `//xla:friends` via
# `includes`.
package_group(
    name = "friends",
    includes = [
        "//xla:friends",
    ],
)

# Library built from fusion.cc / fusion.h.
# Overrides the package default visibility: only packages under
# //xla/backends/gpu/codegen and //xla/service/gpu may depend on it.
# Its only dependency inside this package is :xtile_compiler.
cc_library(
    name = "fusion",
    srcs = ["fusion.cc"],
    hdrs = ["fusion.h"],
    visibility = [
        "//xla/backends/gpu/codegen:__subpackages__",
        "//xla/service/gpu:__subpackages__",
    ],
    deps = [
        ":xtile_compiler",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla/backends/gpu/codegen:fusion_emitter",
        "//xla/backends/gpu/runtime:kernel_thunk",
        "//xla/backends/gpu/runtime:thunk",
        "//xla/codegen/emitters:kernel_arguments",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_constants",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:ir_emitter_context",
        "//xla/service/gpu:kernel_reuse_cache",
        "//xla/service/gpu:launch_dimensions",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)

# Test for :fusion (fusion_test.cc). Carries the "gpu" tag and links
# //xla/tests:xla_internal_test_main together with plain `gtest` (not
# `gtest_main`).
xla_cc_test(
    name = "fusion_test",
    srcs = ["fusion_test.cc"],
    tags = ["gpu"],
    deps = [
        ":fusion",
        "//xla/backends/gpu/codegen:fusion_emitter",
        "//xla/backends/gpu/codegen:fusions",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:target_constants",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
    ],
)

# Library built from emitter_helpers.cc / emitter_helpers.h.
# Declared with `compatible_with = get_compatible_with_portable()` and has no
# dependencies on other targets in this package.
cc_library(
    name = "emitter_helpers",
    srcs = ["emitter_helpers.cc"],
    hdrs = [
        "emitter_helpers.h",
    ],
    compatible_with = get_compatible_with_portable(),
    deps = [
        "//xla:comparison_util",
        "//xla:literal",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/codegen/emitters:elemental_hlo_to_mlir",
        "//xla/codegen/tiling:tiled_hlo_instruction",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/analysis:indexing_analysis",
        "//xla/hlo/ir:hlo",
        "//xla/mlir_hlo",
        "//xla/mlir_hlo:map_mhlo_to_scalar_op",
        "//xla/mlir_hlo:transformation_helpers",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithUtils",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@stablehlo//:stablehlo_ops",
    ],
)

# Library built from lowering_util.cc / lowering_util.h.
# Its only dependency inside this package is :tma_utils.
cc_library(
    name = "lowering_util",
    srcs = ["lowering_util.cc"],
    hdrs = ["lowering_util.h"],
    deps = [
        ":tma_utils",
        "//xla:util",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
    ],
)

# Test for :lowering_util (lowering_util_test.cc). Sets no tags and links
# `gtest_main`.
xla_cc_test(
    name = "lowering_util_test",
    srcs = ["lowering_util_test.cc"],
    deps = [
        ":lowering_util",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:Parser",
        "@triton//:TritonDialects",
    ],
)

# Library exposing compilation_pipeline.h. The sources are one common file
# plus `_cuda` and `_rocm` variants; all three are listed without a select(),
# and the deps cover both the @triton//third_party/nvidia and
# @triton//third_party/amd targets. It depends on no other target in this
# package.
cc_library(
    name = "compilation_pipeline",
    srcs = [
        "compilation_pipeline.cc",
        "compilation_pipeline_cuda.cc",
        "compilation_pipeline_rocm.cc",
    ],
    hdrs = ["compilation_pipeline.h"],
    deps = [
        "//xla/backends/gpu/codegen/emitters/transforms:passes",
        "//xla/backends/gpu/codegen/triton/transforms:passes",
        "//xla/codegen/emitters/transforms:convert_pure_call_ops_pass",
        "//xla/codegen/emitters/transforms:passes",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//mlir:AffineToStandard",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:IndexToLLVM",
        "@llvm-project//mlir:NVVMToLLVM",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:Transforms",
        "@triton//:GluonTransforms",
        "@triton//:TritonDialects",
        "@triton//:TritonGPUToLLVM",
        "@triton//:TritonGPUTransforms",
        "@triton//:TritonInstrumentTransforms",
        "@triton//:TritonLLVMIR",
        "@triton//:TritonNvidiaGPUTransforms",
        "@triton//:TritonToTritonGPU",
        "@triton//:TritonToTritonGPUPasses",
        "@triton//:TritonTransforms",
        "@triton//:WarpSpecialization",
        "@triton//third_party/amd:TritonAMDGPUToLLVM",
        "@triton//third_party/amd:TritonAMDGPUTransforms",
        "@triton//third_party/nvidia:NVGPUToLLVM",
        "@triton//third_party/nvidia:NVHopperTransforms",
        "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
    ],
)

# Test for :compilation_pipeline (compilation_pipeline_test.cc). Carries the
# "gpu" tag and links `gtest_main`.
xla_cc_test(
    name = "compilation_pipeline_test",
    srcs = ["compilation_pipeline_test.cc"],
    tags = ["gpu"],
    deps = [
        ":compilation_pipeline",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],
)

# Library built from fusion_emitter.cc / fusion_emitter.h.
# In-package dependencies: :collective_emitter (flagged for removal by the
# TODO below), :dot_algorithms and :emitter_helpers.
cc_library(
    name = "fusion_emitter",
    srcs = ["fusion_emitter.cc"],
    hdrs = ["fusion_emitter.h"],
    deps = [
        ":collective_emitter",  # TODO(willfroom): Migrate to using stablehlo.allreduce etc.
        ":dot_algorithms",
        ":emitter_helpers",
        "//xla:autotuning_proto_cc",
        "//xla:permutation_util",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/codegen/emitters:elemental_hlo_to_mlir",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/tiling:symbolic_tile_analysis",
        "//xla/codegen/tiling:tiled_hlo_computation",
        "//xla/codegen/tiling:tiled_hlo_fusion_instruction",
        "//xla/codegen/tiling:tiled_hlo_instruction",
        "//xla/codegen/tiling:tiled_hlo_schedule",
        "//xla/codegen/tiling:tiling_specification",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/codegen/xtile/ir/transforms:passes",
        "//xla/hlo/analysis:indexing_analysis",
        "//xla/hlo/builder:xla_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer",
        "//xla/mlir_hlo",
        "//xla/service:hlo_module_config",
        "//xla/service:instruction_fusion",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:AffineToStandard",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:ToLLVMIRTranslation",
        "@local_tsl//tsl/platform:path",
        "@stablehlo//:stablehlo_ops",
    ],
)

# Library built from xtile_compiler.cc / xtile_compiler.h and tagged
# "no-oneapi". In-package dependencies: :collective_emitter,
# :compilation_pipeline, :fusion_emitter, :lowering_util and :support.
cc_library(
    name = "xtile_compiler",
    srcs = ["xtile_compiler.cc"],
    hdrs = ["xtile_compiler.h"],
    tags = ["no-oneapi"],
    deps = [
        ":collective_emitter",
        ":compilation_pipeline",
        ":fusion_emitter",
        ":lowering_util",
        ":support",
        "//xla:autotuning_proto_cc",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/backends/gpu/codegen/triton/transforms:passes",
        "//xla/codegen:ir_printing",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/emitters/transforms:passes",
        "//xla/codegen/tiling:symbolic_tile_analysis",
        "//xla/codegen/tiling:tiling_specification",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/codegen/xtile/ir/transforms:passes",
        "//xla/hlo/analysis:symbolic_expr",
        "//xla/hlo/builder:xla_builder",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer",
        "//xla/service:dump",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/service/gpu/model:triton_emitter_constraints",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor:launch_dim",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tools:hlo_decomposer_lib",
        "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Linker",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AffineToStandard",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithToLLVM",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:ControlFlowToLLVM",
        "@llvm-project//mlir:ExecutionEngineUtils",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncExtensions",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:IndexToLLVM",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LLVMIRTransforms",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:ToLLVMIRTranslation",
        "@llvm-project//mlir:Transforms",
        "@local_tsl//tsl/platform:path",
        "@stablehlo//:stablehlo_ops",
        "@triton//:TritonDialects",
        "@triton//:TritonTransforms",
    ],
)

# Library built from dot_algorithms.cc / dot_algorithms.h.
# Declared with `compatible_with = get_compatible_with_portable()`; its only
# dependency inside this package is :emitter_helpers.
cc_library(
    name = "dot_algorithms",
    srcs = ["dot_algorithms.cc"],
    hdrs = ["dot_algorithms.h"],
    compatible_with = get_compatible_with_portable(),
    deps = [
        ":emitter_helpers",
        "//xla:xla_data_proto_cc",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/translate/hlo_to_mhlo:attribute_importer",
        "//xla/service:algorithm_util",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:Support",
        "@local_tsl//tsl/platform:tensor_float_32_hdr_lib",
        "@stablehlo//:stablehlo_ops",
        "@triton//:TritonDialects",
    ],
)

# Test built from fusion_emitter_deviceless_test.cc against :xtile_compiler.
# Declared with xla_cc_test (no `backends` list) and tagged "no_oss" for the
# reason given on the tag below.
xla_cc_test(
    name = "fusion_emitter_deviceless_test",
    srcs = ["fusion_emitter_deviceless_test.cc"],
    tags = ["no_oss"],  # Doesn't pass in OSS when building with the `fusion_emitter_stub`.
    deps = [
        ":xtile_compiler",
        "//xla:xla_proto_cc",
        "//xla/hlo/analysis:symbolic_expr",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:target_constants",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
    ],
)

# Test built from triton_gemm_fusion_test.cc, declared for the a100, h100,
# b200 and amdgpu_any backends with `use_legacy_runtime = True`.
# In-package dependencies: :test_utils and :xtile_compiler.
xla_test(
    name = "triton_gemm_fusion_test",
    srcs = ["triton_gemm_fusion_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        ":test_utils",
        ":xtile_compiler",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/analysis:symbolic_expr",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:pattern_matcher_gmock",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:pattern_matcher",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:target_constants",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/service/gpu/transforms:hoist_fused_bitcasts",
        "//xla/service/gpu/transforms:nest_gemm_fusion",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:path",
    ],
)

# Test built from fusion_emitter_int4_device_test.cc, declared for the a100,
# h100, b200 and amdgpu_any backends with `use_legacy_runtime = True`.
# It is `size = "large"` (and also tagged "large") and split into 10 shards.
# It lists no dependencies on other targets in this package.
xla_test(
    name = "fusion_emitter_int4_device_test",
    size = "large",
    srcs = ["fusion_emitter_int4_device_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 10,
    tags = [
        "large",
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:path",
    ],
)

# Test built from dot_algorithms_test.cc, declared for the a100, h100, b200
# and amdgpu_any backends with `use_legacy_runtime = True`, split into 30
# shards. Its only in-package dependency is :test_utils.
#
# `backend_args` comes from if_google(): the first dict passes
# "--heap_check=" to the a100, h100 and b200 backends (amdgpu_any has no
# entry); the second argument is an empty dict, presumably the value used
# outside Google builds - confirm against if_google in //xla/tsl:tsl.bzl.
# `env` sets two CUBLAS_* environment variables for the test process.
xla_test(
    name = "dot_algorithms_test",
    srcs = ["dot_algorithms_test.cc"],
    backend_args = if_google(
        {
            "b200": ["--heap_check="],
            "a100": ["--heap_check="],
            "h100": ["--heap_check="],
        },
        {},
    ),
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    env = {
        "CUBLAS_EMULATE_SINGLE_PRECISION": "1",  # Trigger single precision emulation (F32_F32_F32) with BF16x9 cublas algorithm. It was introduced in cublas 12.9.
        "CUBLAS_EMULATION_STRATEGY": "performant",  # Trigger single precision emulation (F32_F32_F32) with BF16x9 cublas algorithm. It was introduced in cublas 12.9.
    },
    shard_count = 30,
    tags = [
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        ":test_utils",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/backends/gpu/profiler:kernel_name_tracer",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:dump",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest",
        "@llvm-project//llvm:Support",
        "@local_tsl//tsl/platform:path",
    ],
)

# Test built from fusion_emitter_device_test.cc, declared for the a100, h100,
# b200 and amdgpu_any backends with `use_legacy_runtime = True`, split into
# 10 shards. In-package dependencies: :support, :test_utils and
# :xtile_compiler.
xla_test(
    name = "fusion_emitter_device_test",
    srcs = ["fusion_emitter_device_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    shard_count = 10,
    tags = [
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        ":support",
        ":test_utils",
        ":xtile_compiler",
        "//xla:autotuning_proto_cc",
        "//xla:error_spec",
        "//xla:literal",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:types",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/analysis:symbolic_expr",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:algorithm_util",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:target_constants",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/stream_executor/rocm:rocm_compute_capability",
        "//xla/tests:test_utils",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
        "@eigen_archive//:eigen3",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@local_tsl//tsl/platform:path",
    ],
)

# Test-only helper library (test_utils.cc / test_utils.h); `testonly = True`
# restricts it to test targets and other testonly targets. Links
# `gtest_for_library` rather than a gtest target that supplies main().
# In-package dependencies: :fusion_emitter and :xtile_compiler.
cc_library(
    name = "test_utils",
    testonly = True,
    srcs = ["test_utils.cc"],
    hdrs = ["test_utils.h"],
    deps = [
        ":fusion_emitter",
        ":xtile_compiler",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass_pipeline",
        "//xla/hlo/testlib:filecheck",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/hlo/transforms/simplifiers:float_normalization",
        "//xla/hlo/utils:hlo_query",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:gpu_float_support",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:matmul_utils",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/service/gpu/model:triton_emitter_constraints",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tests:hlo_test_base",
        "//xla/tests:hlo_test_base_with_mlir_context",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_for_library",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TargetParser",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Test built from fusion_emitter_large_test.cc, declared for the a100, h100,
# b200 and amdgpu_any backends with `use_legacy_runtime = True` and
# `size = "large"`. The static tag list is extended with "requires-mem:16g"
# through if_google(); the "no_oss" tag exists because that memory tag does
# not work in open source (see the inline comment). It lists no dependencies
# on other targets in this package.
xla_test(
    name = "fusion_emitter_large_test",
    size = "large",
    srcs = ["fusion_emitter_large_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "large",
        "no_mac",
        "no_oss",  # requires-mem:16g tag doesn't work in open source
        "nozapfhahn",  # Times out under coverage
        "pjrt_migration_candidate",
    ] + if_google([
        "requires-mem:16g",
    ]),
    use_legacy_runtime = True,
    deps = [
        "//xla:error_spec",
        "//xla:xla_proto_cc",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest",
    ],
)

# Test built from fusion_emitter_parametrized_test.cc, declared for the a100,
# h100, b200 and amdgpu_any backends with `use_legacy_runtime = True`.
# In-package dependencies: :support and :test_utils.
xla_test(
    name = "fusion_emitter_parametrized_test",
    srcs = ["fusion_emitter_parametrized_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        ":support",
        ":test_utils",
        "//xla:comparison_util",
        "//xla:error_spec",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu/tests:gpu_codegen_test",
        "//xla/stream_executor:device_description",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest",
    ],
)

# Test built from fusion_emitter_shared_dialect_test.cc. Declared with
# xla_cc_test (no `backends` list); the "gpu" tag is a temporary workaround
# described by the TODO on `tags`. Its only in-package dependency is
# :test_utils.
xla_cc_test(
    name = "fusion_emitter_shared_dialect_test",
    srcs = ["fusion_emitter_shared_dialect_test.cc"],
    # TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS.
    # Force it to run on GPU temporarily in order to get important OSS coverage.
    tags = [
        "gpu",
        "no_mac",
        "pjrt_migration_candidate",
    ],
    deps = [
        ":test_utils",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/parser:hlo_parser",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/tests:hlo_test_base_with_mlir_context",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest",
    ],
)

# Library bundling two source/header pairs in one target: support.{cc,h} and
# support_legacy.{cc,h}. It depends on no other target in this package.
cc_library(
    name = "support",
    srcs = [
        "support.cc",
        "support_legacy.cc",
    ],
    hdrs = [
        "support.h",
        "support_legacy.h",
    ],
    deps = [
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service:algorithm_util",
        "//xla/service:instruction_fusion",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/functional:overload",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@local_tsl//tsl/platform:tensor_float_32_utils",
    ],
)

# Test built from support_test.cc, split into 25 shards and linked against
# `gtest_main`. Declared with xla_cc_test (no `backends` list); the "gpu" tag
# is a temporary workaround described by the TODO on `tags`.
# In-package dependencies: :fusion_emitter, :support, :test_utils and
# :xtile_compiler.
xla_cc_test(
    name = "support_test",
    srcs = ["support_test.cc"],
    shard_count = 25,
    # TODO(b/353912594): this test does not need to run on GPU, but it is broken on CPU in OSS.
    # Force it to run on GPU temporarily in order to get important OSS coverage.
    tags = [
        "gpu",
        "pjrt_migration_candidate",
    ],
    deps = [
        ":fusion_emitter",
        ":support",
        ":test_utils",
        ":xtile_compiler",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:target_constants",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:protobuf",
    ],
)

# Test built from support_legacy_test.cc, declared for the a100, h100, b200
# and amdgpu_any backends with `use_legacy_runtime = True`.
# In-package dependencies: :fusion_emitter, :support, :test_utils and
# :xtile_compiler.
# NOTE(review): this target links `gtest_main`, whereas the other xla_test
# targets in this file link //xla/tests:xla_internal_test_main (marked
# "fixdeps: keep") with plain `gtest` - confirm the difference is intended.
xla_test(
    name = "support_legacy_test",
    srcs = ["support_legacy_test.cc"],
    backends = [
        "a100",
        "h100",
        "b200",
        "amdgpu_any",
    ],
    tags = [
        "no_mac",
        "pjrt_migration_candidate",
    ],
    use_legacy_runtime = True,
    deps = [
        ":fusion_emitter",
        ":support",
        ":test_utils",
        ":xtile_compiler",
        "//xla:error_spec",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla:xla_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:triton_fusion_analysis",
        "//xla/service/gpu/model:block_level_parameters",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/cuda:cuda_compute_capability",
        "//xla/tsl/lib/core:status_test_util",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

# Library exposing tma_utils.h, implemented in tma_utils.cc.
# No explicit visibility is set, so the package default (":friends") applies.
cc_library(
    name = "tma_utils",
    srcs = ["tma_utils.cc"],
    hdrs = ["tma_utils.h"],
    deps = [
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/service/gpu:matmul_utils",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@llvm-project//llvm:Support",
    ],
)

# Library exposing collective_emitter.h, implemented in collective_emitter.cc.
# Builds on :emitter_helpers from this package plus XLA, Abseil, LLVM/MLIR,
# StableHLO and Triton targets. No explicit visibility is set, so the package
# default (":friends") applies.
cc_library(
    name = "collective_emitter",
    srcs = ["collective_emitter.cc"],
    hdrs = ["collective_emitter.h"],
    deps = [
        ":emitter_helpers",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:types",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/backends/gpu/runtime:all_reduce",
        "//xla/codegen/tiling:tiled_hlo_instruction",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/ir:hlo",
        "//xla/service:collective_ops_utils",
        "//xla/service/gpu:backend_configs_cc",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/gpu:launch_dimensions",
        "//xla/stream_executor:device_description",
        "//xla/stream_executor/gpu:all_reduce_kernel",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@stablehlo//:stablehlo_ops",
        "@triton//:TritonDialects",
    ],
)

# Test target compiled from collective_emitter_test.cc; links against
# :collective_emitter together with :fusion and :xtile_compiler.
xla_cc_test(
    name = "collective_emitter_test",
    srcs = ["collective_emitter_test.cc"],
    # NOTE(review): the "gpu" tag presumably routes the test to GPU-enabled
    # test runs - confirm against the xla_cc_test macro / CI configuration.
    tags = ["gpu"],
    deps = [
        ":collective_emitter",
        ":fusion",
        ":xtile_compiler",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla/backends/gpu/codegen:fusion_emitter",
        "//xla/backends/gpu/codegen:fusions",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/utils:hlo_query",
        "//xla/service:hlo_creation_utils",
        "//xla/service/gpu:gpu_device_info_for_tests",
        "//xla/service/gpu:hlo_fusion_analysis",
        "//xla/stream_executor:device_description",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:ir_headers",
        "@llvm-project//mlir:IR",
    ],
)

# Test target compiled from tma_utils_test.cc; links against :tma_utils.
# No tags are set on this target.
xla_cc_test(
    name = "tma_utils_test",
    srcs = ["tma_utils_test.cc"],
    deps = [
        ":tma_utils",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)

# Library exposing tiled_emitter_constraints.h, implemented in
# tiled_emitter_constraints.cc. Builds on :emitter_helpers and the
# //xla/codegen/tiling targets. No explicit visibility is set, so the package
# default (":friends") applies.
cc_library(
    name = "tiled_emitter_constraints",
    srcs = ["tiled_emitter_constraints.cc"],
    hdrs = ["tiled_emitter_constraints.h"],
    # Value is supplied by get_compatible_with_portable(), loaded from
    # //xla/tsl:tsl.default.bzl at the top of this file.
    compatible_with = get_compatible_with_portable(),
    deps = [
        ":emitter_helpers",
        "//xla:util",
        "//xla/codegen/tiling:affine_map_evaluator",
        "//xla/codegen/tiling:constraint_expression",
        "//xla/codegen/tiling:symbolic_tile",
        "//xla/codegen/tiling:symbolic_tile_analysis",
        "//xla/codegen/tiling:symbolic_tiled_hlo_instruction",
        "//xla/hlo/analysis:indexing_analysis",
        "//xla/hlo/analysis:interval",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)

# Test target compiled from tiled_emitter_constraints_test.cc; links against
# :tiled_emitter_constraints. No tags are set on this target.
xla_cc_test(
    name = "tiled_emitter_constraints_test",
    srcs = ["tiled_emitter_constraints_test.cc"],
    deps = [
        ":tiled_emitter_constraints",
        "//xla/codegen/tiling:symbolic_tile_analysis",
        "//xla/codegen/tiling:tiling_specification",
        "//xla/hlo/analysis:symbolic_expr",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
        "//xla/hlo/testlib:verified_hlo_module",
        "//xla/service:instruction_fusion",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status:status_matchers",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//mlir:IR",
    ],
)
