Armeet Singh Jatyani

Founder · AI Researcher & Engineer


Telegram Alerts for Training

I often use Weights & Biases for training; others use TensorBoard or MLflow. These are great for tracking experiments and sweeps or monitoring logs, but sometimes you want a simple alert system that wakes you up when a run starts or fails, with a little more control so you can check experiment status at a glance.

SLURM can send email alerts, but they clog up my inbox and only cover job-level events, so I can't report anything more granular. A simple Telegram alert logger works great for me instead: I get a high-level overview of my experiments on my phone when I'm on the go, and I can even set up alerts that wake me up. This has helped me while working on a rebuttal and during other time crunches, where a failed run can eat up precious time.

So here's the setup.

Set Up Telegram

  1. Install Telegram.
  2. Follow Telegram's guide to create a new bot with @BotFather, and copy the bot token somewhere safe.
  3. In your SLURM batch script, add the line:
    Shell
    export TELEGRAM_BOT_TOKEN=<your_token_here>
  4. Next, create a new group chat and add the bot to it. The chat ID is hard to find in the phone app, so I open Telegram Web and copy it from the URL, leaving out the # but keeping the leading minus sign. (Screenshot: Telegram Web URL containing the chat ID.)
  5. Add the chat ID to your SLURM batch script:
    Shell
    export TELEGRAM_CHAT_ID=-12345...
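If digging the chat ID out of the Telegram Web URL is awkward, the Bot API's getUpdates method is another way to find it: once the bot has been added to the group, the returned updates contain each chat's `id`. The sketch below assumes Python 3.9+ and uses only the standard library; `fetch_updates` and `extract_chat_ids` are hypothetical helper names, and only a few update types are inspected. Note that getUpdates does not work while a webhook is set, and Telegram keeps pending updates for at most 24 hours.

```python
import json
import urllib.request


def fetch_updates(bot_token: str) -> dict:
    """Call the Bot API's getUpdates method and return the decoded JSON response."""
    url = f"https://api.telegram.org/bot{bot_token}/getUpdates"
    with urllib.request.urlopen(url, timeout=10) as resp:
        return json.load(resp)


def extract_chat_ids(updates: dict) -> dict[int, str]:
    """Map chat id -> a readable chat name for every chat seen in the updates."""
    chats: dict[int, str] = {}
    for update in updates.get("result", []):
        # Each of these update payloads carries a "chat" object with the chat's id.
        for key in ("message", "edited_message", "channel_post", "my_chat_member"):
            payload = update.get(key)
            if payload and "chat" in payload:
                chat = payload["chat"]
                # Groups have a title; private chats have a username or first name.
                name = chat.get("title") or chat.get("username") or chat.get("first_name", "")
                chats[chat["id"]] = name
    return chats
```

Calling `extract_chat_ids(fetch_updates(token))` once from a login node should list the group with its negative ID. If the list comes back empty, send a command such as /start in the group and try again.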

Logger Code

We'll implement the logger as a PyTorch Lightning callback. If you use plain PyTorch, the same logic works as a hook called from your own training loop.
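For a hand-written PyTorch loop, the idea reduces to an object you call at the end of each epoch. The sketch below is framework-agnostic and does no networking: `EpochNotifier` is a hypothetical name, `send` is any callable that delivers text (in practice it could wrap the same sendMessage request the Lightning callback uses), and the demo loop feeds it made-up metric values.

```python
import time
from typing import Callable


class EpochNotifier:
    """Hypothetical epoch hook for a hand-written PyTorch training loop."""

    def __init__(self, send: Callable[[str], object], run_name: str, max_epochs: int):
        self.send = send  # any callable that delivers text, e.g. a sendMessage wrapper
        self.run_name = run_name
        self.max_epochs = max_epochs
        self.start_time = time.time()

    def on_epoch_end(self, epoch: int, metrics: dict[str, float]) -> None:
        """Format an HTML epoch summary and hand it to send()."""
        elapsed_h = (time.time() - self.start_time) / 3600
        lines = [
            f"<b>Run:</b> {self.run_name}",
            f"<b>Epoch:</b> {epoch + 1}/{self.max_epochs}",
            f"<b>Elapsed:</b> {elapsed_h:.1f}h",
        ]
        lines += [f"<b>{name}:</b> {value:.4f}" for name, value in metrics.items()]
        self.send("\n".join(lines))


# Demo: a list stands in for Telegram, and the metric values are made up.
sent: list[str] = []
notifier = EpochNotifier(sent.append, run_name="demo", max_epochs=3)
for epoch in range(3):
    # ... train for one epoch and run validation here ...
    notifier.on_epoch_end(epoch, {"val_mse": 0.01 * (3 - epoch)})
```

In a DDP job you would create and call the notifier only on rank 0, mirroring the `is_global_zero` checks in the Lightning callback.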

Simply add it to your list of callbacks:

Python
import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# `checkpoint_dir`, `run_name`, and `rank` (this process's global rank)
# are defined elsewhere in the training script.
callbacks: list[pl.Callback] = [
    ModelCheckpoint(
        monitor="val_mse",
        mode="min",
        dirpath=checkpoint_dir,
        save_top_k=-1,
        filename="{epoch}-{val_mse:.4f}",
    ),
]

# Add Telegram notifications if configured
telegram_bot_token = os.environ.get("TELEGRAM_BOT_TOKEN")
telegram_chat_id = os.environ.get("TELEGRAM_CHAT_ID")

if telegram_bot_token and telegram_chat_id and rank == 0:
    telegram_callback = TelegramNotificationCallback(
        bot_token=telegram_bot_token,
        chat_id=telegram_chat_id,
        run_name=run_name,
    )
    callbacks.append(telegram_callback)
    print(f"Telegram notifications enabled for run: {run_name}")
...
trainer = pl.Trainer(
    # ... other Trainer arguments ...
    callbacks=callbacks,
)
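The snippet above relies on a `rank` variable from elsewhere in the script. If you need it before the Trainer exists, one option is to read the environment variables the launcher sets: torchrun sets `RANK`, and `srun` sets `SLURM_PROCID` for each task. A minimal sketch under that assumption (`global_rank_from_env` is a hypothetical helper):

```python
import os


def global_rank_from_env() -> int:
    """Best-effort global rank before the Trainer exists (assumes torchrun or srun)."""
    # Prefer RANK: under torchrun inside a SLURM job, SLURM_PROCID is the srun task
    # index rather than the per-GPU process rank. Default to 0 for single-process runs.
    for var in ("RANK", "SLURM_PROCID"):
        value = os.environ.get(var)
        if value is not None:
            return int(value)
    return 0
```

Since every hook in the callback also checks `trainer.is_global_zero`, this rank check mainly avoids constructing the callback on other ranks; duplicate messages are already prevented.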

The callback class itself is straightforward. It calls two Bot API methods, sendMessage and editMessageText, and uses Telegram's HTML parse mode for basic styling. You could also add inline buttons to trigger actions from your phone, but I haven't included that here.
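Telegram's HTML mode supports only a small set of tags (such as `<b>`, `<i>`, `<code>`, `<pre>`, and `<a href>`), and any literal `<`, `>`, or `&` in the text must be escaped, otherwise the API rejects the message. Here is a small sketch of escaping dynamic values such as run names with the standard library's `html.escape`; `bold_field` and `format_fields` are hypothetical helpers, not part of any Telegram library.

```python
import html


def bold_field(label: str, value: object) -> str:
    """One '<b>label:</b> value' line with both parts escaped for HTML parse mode."""
    # quote=False: Telegram only requires <, > and & to be escaped in message text.
    return f"<b>{html.escape(label, quote=False)}:</b> {html.escape(str(value), quote=False)}"


def format_fields(title: str, fields: dict[str, object]) -> str:
    """A bold title, a blank line, then one escaped line per field."""
    lines = [f"<b>{html.escape(title, quote=False)}</b>", ""]
    lines += [bold_field(label, value) for label, value in fields.items()]
    return "\n".join(lines)
```

For example, a run name like `lr<1e-3 & bs=32` becomes `lr&lt;1e-3 &amp; bs=32`, which Telegram renders back as the original text.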

The code below sends a start message, then keeps editing a second message with a progress bar and the current model performance (PSNR, since I'm working on MRI reconstruction). Here's what it looks like while training:

(Screenshot: Telegram chat with the training-start message and the live-updating metrics message.)
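The progress message reports both `val_mse` and `val_psnr`. For images scaled to [0, data_range], the two are linked by PSNR = 10·log10(data_range² / MSE), so a 10× drop in MSE is a 10 dB gain. The post doesn't say how `val_psnr` is computed; MRI reconstruction code often computes PSNR per volume with the data range taken from the target, so this sketch, which assumes a fixed `data_range` (and a hypothetical helper name `psnr_from_mse`), won't necessarily reproduce a logged value exactly.

```python
import math


def psnr_from_mse(mse: float, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB, assuming pixel values lie in [0, data_range]."""
    if mse <= 0:
        return math.inf  # perfect reconstruction
    return 10.0 * math.log10(data_range ** 2 / mse)
```

With `data_range=1.0`, an MSE of 0.01 corresponds to 20 dB and an MSE of 0.001 to about 30 dB.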

And here is the code for the TelegramNotificationCallback class.

Python
import html
import time

import pytorch_lightning as pl
import requests
import torch


class TelegramNotificationCallback(pl.Callback):
    """Custom callback for Telegram notifications during training."""

    def __init__(self, bot_token: str, chat_id: str, run_name: str):
        super().__init__()
        self.bot_token = bot_token
        self.chat_id = chat_id
        self.run_name = run_name
        # HTML parse mode rejects unescaped <, > and &, so escape the run name once.
        self.run_name_html = html.escape(run_name, quote=False)
        self.start_time = None
        self.last_epoch_update = 0
        self.progress_message_id = None
        self.start_message_id = None

    def _send_telegram_message(self, message: str):
        """Send a message via the Bot API; return its message_id, or None on failure."""
        try:
            url = f"https://api.telegram.org/bot{self.bot_token}/sendMessage"
            data = {
                "chat_id": self.chat_id,
                "text": message,
                "parse_mode": "HTML",
            }
            response = requests.post(url, data=data, timeout=10)
            response.raise_for_status()
            return response.json()["result"]["message_id"]
        except Exception as e:
            # Never let a notification problem crash the training run.
            print(f"Failed to send Telegram message: {e}")
            return None

    def _edit_telegram_message(self, message_id: int, message: str):
        """Replace the text of an existing Telegram message."""
        try:
            url = f"https://api.telegram.org/bot{self.bot_token}/editMessageText"
            data = {
                "chat_id": self.chat_id,
                "message_id": message_id,
                "text": message,
                "parse_mode": "HTML",
            }
            response = requests.post(url, data=data, timeout=10)
            response.raise_for_status()
        except Exception as e:
            print(f"Failed to edit Telegram message: {e}")

    @staticmethod
    def _format_metric(metric) -> str:
        """Format a tensor or number to 4 decimals; pass anything else (e.g. 'N/A') through."""
        if isinstance(metric, torch.Tensor):
            return f"{metric.item():.4f}"
        if isinstance(metric, (int, float)):
            return f"{metric:.4f}"
        return str(metric)

    def _create_progress_bar(self, current: int, total: int, width: int = 20) -> str:
        """Create a text progress bar from full-block and light-shade characters."""
        filled = min(width, max(0, int(width * current / max(total, 1))))
        return "█" * filled + "░" * (width - filled)

    def _create_progress_message(self, trainer, current_epoch: int) -> str:
        """Create a progress message with visual elements."""
        if self.start_time is None:
            return ""

        # max_epochs may be None or -1 (no epoch limit); fall back to 1 to keep the math valid.
        max_epochs = max(trainer.max_epochs or 1, 1)
        progress = (current_epoch + 1) / max_epochs * 100
        elapsed_time = time.time() - self.start_time

        # Latest logged metrics; missing keys are shown as "N/A"
        metrics = trainer.callback_metrics
        train_loss = metrics.get("train_loss", "N/A")
        val_loss = metrics.get("val_loss", "N/A")
        val_mse = metrics.get("val_mse", "N/A")
        val_psnr = metrics.get("val_psnr", "N/A")

        progress_bar = self._create_progress_bar(current_epoch + 1, max_epochs)

        # Braille spinner frame picked from wall-clock time, so it changes between edits
        animation_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
        anim_char = animation_chars[int(time.time() * 2) % len(animation_chars)]

        message = f"{anim_char} <b>Training Progress</b>\n\n"
        message += f"<b>Run:</b> {self.run_name_html}\n"
        message += f"<b>Epoch:</b> {current_epoch + 1}/{max_epochs} ({progress:.1f}%)\n"
        message += f"<b>Elapsed:</b> {elapsed_time / 3600:.1f}h\n\n"
        message += f"<code>{progress_bar}</code>\n\n"
        message += f"<b>Train Loss:</b> {self._format_metric(train_loss)}\n"
        message += f"<b>Val Loss:</b> {self._format_metric(val_loss)}\n"
        message += f"<b>Val MSE:</b> {self._format_metric(val_mse)}\n"
        message += f"<b>Val PSNR:</b> {self._format_metric(val_psnr)}"

        return message

    def on_fit_start(self, trainer, pl_module):
        """Called when training starts."""
        if trainer.is_global_zero:  # Only send from rank 0
            self.start_time = time.time()
            message = "🚀 <b>Training Started</b>\n\n"
            message += f"<b>Run:</b> {self.run_name_html}\n"
            message += f"<b>Model:</b> {pl_module.__class__.__name__}\n"
            message += f"<b>Max Epochs:</b> {trainer.max_epochs}\n"
            message += f"<b>Devices:</b> {trainer.num_devices}\n"
            message += f"<b>Nodes:</b> {trainer.num_nodes}"

            self.start_message_id = self._send_telegram_message(message)

            # Send the progress message that later epochs will edit in place
            progress_message = self._create_progress_message(trainer, 0)
            self.progress_message_id = self._send_telegram_message(progress_message)

    def on_train_epoch_end(self, trainer, pl_module):
        """Called at the end of each training epoch."""
        if trainer.is_global_zero and self.progress_message_id is not None:
            current_epoch = trainer.current_epoch

            # Edit the progress message once per epoch
            progress_message = self._create_progress_message(trainer, current_epoch)
            self._edit_telegram_message(self.progress_message_id, progress_message)

            # Update last_epoch_update for potential future use
            self.last_epoch_update = current_epoch

    def on_fit_end(self, trainer, pl_module):
        """Called when training ends."""
        if trainer.is_global_zero and self.start_time is not None:
            total_time = time.time() - self.start_time
            final_val_mse = self._format_metric(trainer.callback_metrics.get("val_mse", "N/A"))
            final_val_psnr = self._format_metric(trainer.callback_metrics.get("val_psnr", "N/A"))

            # Update the progress message to show completion
            if self.progress_message_id is not None:
                completion_message = "✅ <b>Training Completed</b>\n\n"
                completion_message += f"<b>Run:</b> {self.run_name_html}\n"
                completion_message += f"<b>Total Time:</b> {total_time / 3600:.1f}h\n"
                completion_message += f"<b>Final Val MSE:</b> {final_val_mse}\n"
                completion_message += f"<b>Final Val PSNR:</b> {final_val_psnr}\n\n"
                completion_message += f"<code>{'█' * 20}</code> 100%"

                self._edit_telegram_message(self.progress_message_id, completion_message)

            # Send a separate completion message so it triggers a fresh notification
            final_message = "🎉 <b>Training Successfully Completed!</b>\n\n"
            final_message += f"<b>Run:</b> {self.run_name_html}\n"
            final_message += f"<b>Total Time:</b> {total_time / 3600:.1f}h\n"
            final_message += f"<b>Final Val MSE:</b> {final_val_mse}\n"
            final_message += f"<b>Final Val PSNR:</b> {final_val_psnr}"

            self._send_telegram_message(final_message)
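The callback above reports start, progress, and completion, but sends nothing when a run crashes. One simple way to cover crashes, including ones raised during setup before the first epoch, is to wrap the call to `trainer.fit` so that any exception sends a message and is then re-raised. A minimal sketch, where `notify_on_failure` is a hypothetical helper and `send` is any function that posts text:

```python
import html
import traceback
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def notify_on_failure(send: Callable[[str], object], run_name: str) -> Iterator[None]:
    """Send an HTML failure alert if the wrapped block raises, then re-raise."""
    try:
        yield
    except BaseException as exc:  # BaseException so Ctrl-C and SystemExit are reported too
        # The last line of a formatted traceback is "ExceptionType: message".
        summary = traceback.format_exception(type(exc), exc, exc.__traceback__)[-1].strip()
        send(
            "❌ <b>Training Failed</b>\n\n"
            f"<b>Run:</b> {html.escape(run_name, quote=False)}\n"
            f"<code>{html.escape(summary, quote=False)}</code>"
        )
        raise
```

Usage might look like `with notify_on_failure(telegram_callback._send_telegram_message, run_name): trainer.fit(model, datamodule=datamodule)`, where `model` and `datamodule` stand in for your own objects, guarded by the same rank check so only one process reports. Lightning also has a `Callback.on_exception(trainer, pl_module, exception)` hook if you'd rather keep this inside the callback. Either way, a job killed outright (for example OOM-killed or stopped at its SLURM time limit) may never reach Python's error handling, so SLURM's own failure email is still a useful backstop.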

You could also modify this code to send you a link to your W&B run.
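One way to do that is an HTML link plus a URL button under the message, using the Bot API's `reply_markup` field with an inline keyboard; unlike callback buttons, URL buttons need no bot-side handling. In the sketch below, `wandb_link_payload` is a hypothetical helper that only builds the sendMessage form fields. It assumes you already have the run URL: with wandb the active run exposes it as `wandb.run.url`, and with Lightning's `WandbLogger` that run object is `trainer.logger.experiment`.

```python
import html
import json


def wandb_link_payload(chat_id: str, run_name: str, run_url: str) -> dict[str, str]:
    """sendMessage form fields with an HTML link and a URL button to a W&B run."""
    keyboard = {"inline_keyboard": [[{"text": "Open in W&B", "url": run_url}]]}
    return {
        "chat_id": chat_id,
        "text": f'🔗 <a href="{html.escape(run_url)}">{html.escape(run_name, quote=False)}</a>',
        "parse_mode": "HTML",
        # Sent as a form field, reply_markup must be a JSON-serialized object.
        "reply_markup": json.dumps(keyboard),
    }
```

The resulting dict can be posted the same way the callback posts its messages, e.g. `requests.post(f"https://api.telegram.org/bot{token}/sendMessage", data=payload, timeout=10)`.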