I have constructed a TFX pipeline, but when I increased the number of audio samples from 350 to 5,000 I got an OOM (out-of-memory) error at the Transform component. I haven't found any parameters through which I can handle this problem. Here is my pipeline:
import os
import tensorflow as tf
from tfx.components import Transform, Trainer
from tfx.components import Evaluator, Pusher
from tfx.dsl.components.common.resolver import Resolver
from tfx.components import ExampleValidator, SchemaGen, StatisticsGen
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import example_gen_pb2
from tfx.proto import trainer_pb2
from tfx.components import FileBasedExampleGen
from tfx.dsl.components.base import executor_spec
from CustomExampleGen import *
from sys import exit
import tensorflow_data_validation as tfdv
# Create ONE InteractiveContext, rooted at the artifact store, so every
# component run is recorded in a single metadata store.  (Previously a
# second, argument-less context was created first; it wrote its artifacts
# to a throwaway temp directory and was then shadowed.)
artifact_store_path = "/home/marlon/Área de Trabalho/telnyx/audio_classification"
context = InteractiveContext(pipeline_root=artifact_store_path)

# Route the train/eval splits to their matching .pcap file patterns.
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*.pcap'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*.pcap'),
])

# TODO: add an output_config here if custom split proportions are needed.
# Ingest raw .pcap files with the project's custom ExampleGen executor.
example_gen = FileBasedExampleGen(
    input_base="/home/marlon/Área de Trabalho/telnyx/audio_classification/data/",
    input_config=input_config,
    custom_executor_spec=executor_spec.ExecutorClassSpec(BaseExampleGenExecutor)
)
context.run(example_gen)
# Compute dataset statistics, infer a schema from them, and validate the
# examples against that schema.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)

# Resolve the statistics artifact once and reuse it (the original resolved
# outputs['statistics'].get()[0] twice).
statistics_artifact = statistics_gen.outputs['statistics'].get()[0]
print(statistics_artifact)
statistics_artifact_uri = statistics_artifact.uri

schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
context.run(schema_gen)

example_validator = ExampleValidator(statistics=statistics_gen.outputs['statistics'], schema=schema_gen.outputs['schema'])
context.run(example_validator)
# Cap the Beam DirectRunner at a single worker to limit peak memory.
# NOTE(review): the DirectRunner processes the dataset on one machine, so
# large datasets can still OOM here.  The Transform component itself
# exposes no batch-size parameter; to bound memory, either set the
# tf.Transform batch size inside preprocessing.py (e.g. wrap the work in
# `tft_beam.Context(desired_batch_size=...)` — confirm against the
# tensorflow_transform version in use) or move to a distributed runner
# such as DataflowRunner.
beam_args = [
    '--direct_num_workers=1',
]

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath("preprocessing.py"),
    force_tf_compat_v1=True  # use the TF1-compat Transform path; unrelated to batch size
).with_beam_pipeline_args(beam_args)
context.run(transform)
# Train on the materialized Transform outputs, reusing its transform graph.
DESIRED_EPOCHS = 130

# Name the step-count protos before constructing the component.
train_args = trainer_pb2.TrainArgs(num_steps=300)
eval_args = trainer_pb2.EvalArgs(num_steps=80)

trainer = Trainer(
    module_file=os.path.abspath("module.py"),
    transformed_examples=transform.outputs['transformed_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=train_args,
    eval_args=eval_args,
    custom_config={'epochs': DESIRED_EPOCHS}  # epoch count passed through as custom config
)
context.run(trainer)
I decided to use TFX because I thought it was supposed to handle large amounts of data well…