How to handle OOM in the TFX Transform component?

I have built a TFX pipeline, but after increasing the number of audio samples from 350 to 5000 the Transform component fails with an out-of-memory (OOM) error. I haven't found any parameters on the component that would let me work around this. Here is my pipeline:


import os
import tensorflow as tf
from tfx.components import Transform, Trainer
from tfx.components import Evaluator, Pusher
from tfx.dsl.components.common.resolver import Resolver
from tfx.components import ExampleValidator, SchemaGen, StatisticsGen
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import example_gen_pb2
from tfx.proto import trainer_pb2
from tfx.components import FileBasedExampleGen
from tfx.dsl.components.base import executor_spec
from CustomExampleGen import *
from sys import exit
import tensorflow_data_validation as tfdv

# Update the input_config to include train and eval patterns
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*.pcap'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*.pcap'),
])

# TODO: add output config
example_gen = FileBasedExampleGen(
    input_base="/home/marlon/Área de Trabalho/telnyx/audio_classification/data/",
    input_config=input_config,
    custom_executor_spec=executor_spec.ExecutorClassSpec(BaseExampleGenExecutor)
)

artifact_store_path = "/home/marlon/Área de Trabalho/telnyx/audio_classification"

# Initialize the InteractiveContext with the artifact store path
context = InteractiveContext(pipeline_root=artifact_store_path)

context.run(example_gen)

statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)
statistics_artifact = statistics_gen.outputs['statistics'].get()[0]
print(statistics_artifact)
statistics_artifact_uri = statistics_gen.outputs['statistics'].get()[0].uri

schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
context.run(schema_gen)

example_validator = ExampleValidator(statistics=statistics_gen.outputs['statistics'], schema=schema_gen.outputs['schema'])
context.run(example_validator)

beam_args = [
    '--direct_num_workers=1',
]
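# (With direct_num_workers=1 the DirectRunner stays on a single in-memory
#  worker, so the whole Transform analysis runs in one process.)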


transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath("preprocessing.py"),
    force_tf_compat_v1=True  # run the preprocessing through the TF 1.x compatibility path
).with_beam_pipeline_args(beam_args)
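# This is the step that runs out of memory once the dataset grows from 350 to 5000 samples.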
context.run(transform)

DESIRED_EPOCHS = 130

trainer = Trainer(
    module_file=os.path.abspath("module.py"),
    transformed_examples=transform.outputs['transformed_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=300),
    eval_args=trainer_pb2.EvalArgs(num_steps=80),  # similarly calculate eval_steps if needed
    custom_config={'epochs': DESIRED_EPOCHS}  # Pass epochs as a custom config.
)
context.run(trainer)

I decided to use TFX because I thought it was supposed to handle large amounts of data just fine…
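For what it's worth, the only batch-size control I could find lives in tf.Transform itself rather than on the Transform component: the desired_batch_size argument of tft_beam.Context. Below is a minimal sketch of that option in a standalone tf.Transform run, patterned on the tf.Transform getting-started example (the toy 'x' feature is just a placeholder for my real audio features). I don't see a way to pass this through the TFX component, which is essentially my question:

import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

# Toy stand-in data and schema for the real audio features.
raw_data = [{'x': float(i)} for i in range(10)]
raw_data_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

def preprocessing_fn(inputs):
    return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

# desired_batch_size caps how many examples tf.Transform pushes through the
# preprocessing graph at once; this is the kind of knob I was hoping to find
# on the Transform component.
with tft_beam.Context(temp_dir=tempfile.mkdtemp(), desired_batch_size=128):
    (transformed_data, transformed_metadata), transform_fn = (
        (raw_data, raw_data_metadata)
        | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))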

Are you still having this issue? If you can share reproducible code and data, I can try to help you with it.
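In the meantime, one thing that might be worth trying is spreading the Beam work across several worker processes instead of a single in-memory one. This is only a sketch, reusing the components from your pipeline above; the flags are standard apache_beam DirectRunner options (DirectOptions), not Transform parameters, and whether they help depends on where the memory is actually going:

beam_args = [
    '--direct_running_mode=multi_processing',  # subprocess workers instead of one in-memory worker
    '--direct_num_workers=4',                  # spread bundles (and their memory) across 4 processes
]

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath("preprocessing.py"),
    force_tf_compat_v1=True,
).with_beam_pipeline_args(beam_args)
context.run(transform)

If the memory is dominated by how large each decoded audio example becomes inside your preprocessing_fn, the runner flags alone may not be enough, which is why a reproducible example would help.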