Training hybrid models is hard, and papers tend to gloss over the practical engineering work that goes into building good ones. The purpose of this cookbook is to enable other technical groups to hit the ground running when building their own hybrid (SSM, Transformer, MoE) models.
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Many LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc.) boil down to performing very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again, and so on. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in parameters and data, which reduces the importance of architectural changes except insofar as they change the scaling-law coefficients. At smaller scales, however, where for example the parameter count is fixed by hard memory limits, architectural efficiencies that give a constant improvement at a given scale become important and can enable models to significantly outperform at a fixed inference FLOP and memory budget. The same effect appears in training: a superior architecture lets a model compete with standard transformers trained on significantly more tokens (and hence significantly more FLOPs), because training far past the Chinchilla-optimal point at a fixed parameter count runs into strongly sublinear scaling. As a result, a small absolute improvement due to architecture can overcome a 2-10x token budget advantage far from the Chinchilla-optimal point, as we observe with our Zamba1 and Zamba2 models.
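The sublinear-scaling argument can be made concrete with the parametric loss form from the Chinchilla paper, L(N, D) = E + A / N^alpha + B / D^beta. The sketch below is illustrative only: the constants are approximately the fits reported by Hoffmann et al. (2022) and are assumptions here, not values fitted to Zamba, and the function names are hypothetical.

```python
# Illustrative constants (assumed; roughly the published Chinchilla fit).
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """Parametric loss L(N, D) = E + A / N^alpha + B / D^beta."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


def loss_gain_from_more_tokens(n_params: float, n_tokens: float, factor: float) -> float:
    """Loss reduction obtained by multiplying the token budget by `factor`
    while keeping the parameter count fixed."""
    return chinchilla_loss(n_params, n_tokens) - chinchilla_loss(n_params, n_tokens * factor)


if __name__ == "__main__":
    n_params = 1.2e9  # hypothetical fixed parameter budget
    for n_tokens in (24e9, 240e9, 2.4e12):
        gain = loss_gain_from_more_tokens(n_params, n_tokens, 10)
        print(f"{n_tokens:.1e} -> {10 * n_tokens:.1e} tokens: loss drops by {gain:.4f}")
```

Under this form, each further 10x of tokens buys a smaller loss reduction than the previous one, so a fixed architectural gain is worth an increasingly large multiple of the token budget.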
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is the sequence mixer. This is because attention is expensive at long sequence lengths, while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, and RetNet perform well at short-context language modelling while lagging at long-context reasoning, information retrieval, and in-context learning. Despite this deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
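A rough cost model illustrates why attention becomes expensive at long sequence lengths. The sketch below is a back-of-the-envelope estimate under stated assumptions: plain multi-head attention without grouped-query attention, 2-byte elements, and a Mamba-style mixer whose decode-time state is approximated as a recurrent state plus a small convolution window. All function names and default values are hypothetical.

```python
def kv_cache_bytes(seq_len: int, n_layers: int, d_model: int, bytes_per_elem: int = 2) -> int:
    """Keys and values: two (seq_len, d_model) tensors per attention layer.
    Grows linearly with sequence length."""
    return 2 * seq_len * n_layers * d_model * bytes_per_elem


def ssm_state_bytes(n_layers: int, d_inner: int, d_state: int = 16,
                    d_conv: int = 4, bytes_per_elem: int = 2) -> int:
    """Approximate decode-time state of a Mamba-style layer: a (d_inner, d_state)
    recurrent state plus a (d_inner, d_conv) convolution window.
    Independent of sequence length."""
    return n_layers * d_inner * (d_state + d_conv) * bytes_per_elem


def attention_score_flops(seq_len: int, d_model: int) -> int:
    """QK^T and the attention-weighted sum over V each cost about 2 * L^2 * d FLOPs
    per layer (ignoring causal-mask savings and the linear projections)."""
    return 4 * seq_len**2 * d_model


if __name__ == "__main__":
    for seq_len in (4096, 32768):
        print(seq_len,
              kv_cache_bytes(seq_len, n_layers=32, d_model=4096),
              ssm_state_bytes(n_layers=32, d_inner=8192),
              attention_score_flops(seq_len, d_model=4096))
```

The KV cache and the attention-score FLOPs grow with sequence length (linearly and quadratically, respectively), while the recurrent state stays fixed.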
This motivates a hybrid architecture that mixes attention with linear sequence mixers such as Mamba. The majority of the sequence mixers are then more efficient than attention, while just enough full attention is used to maintain performance. Empirically, full attention is not needed at every sequence mixer; substantially less attention suffices, which is what makes hybrids work. Similar findings have recently been popularized for transformers, with some recent models, such as those used by CharacterAI, claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of combining cheap local sequence mixers with occasional full attention, but it is less efficient than a Mamba hybrid because sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason that so little attention is needed relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations (see, for example, the surprising effectiveness of pure N-gram models). Occasionally, however, long-range information retrieval or other in-context learning is required, which a small number of attention layers can handle. In our experiments, we observe that only between 1/4 and 1/6 of the sequence mixer layers should be full attention, a phenomenon also reported here.
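As a minimal sketch of what such a layout looks like, the helper below builds a layer schedule with one full-attention layer for every `attention_every` sequence mixers. The function name and the exact placement of the attention layers are assumptions for illustration; the text above only fixes the ratio (roughly 1 in 4 to 1 in 6), not where the attention layers sit.

```python
def hybrid_schedule(n_layers: int, attention_every: int = 6) -> list[str]:
    """Return one label per sequence-mixer layer, placing a full-attention
    layer at every `attention_every`-th position and Mamba elsewhere."""
    if attention_every < 1:
        raise ValueError("attention_every must be >= 1")
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_layers)
    ]


if __name__ == "__main__":
    schedule = hybrid_schedule(12, attention_every=6)
    print(schedule)
    print("attention fraction:", schedule.count("attention") / len(schedule))
```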
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by using a parameter-sharing scheme whereby a single transformer block, consisting of an attention block and an MLP block, is re-used multiple times throughout the network. This shared block comprises the only attention in the network. Sharing increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
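The trade-off described above (fewer parameters, same FLOPs per invocation) can be sketched with simple accounting. This is an illustrative estimate, not the Zamba implementation: it assumes a standard block with four d x d attention projections and an MLP with a 4x expansion, ignores biases, norms, and the widened input discussed later, and uses hypothetical function names.

```python
def transformer_block_params(d_model: int, mlp_ratio: int = 4) -> int:
    """Approximate weight count of one attention + MLP block."""
    attn = 4 * d_model * d_model             # Q, K, V, and output projections
    mlp = 2 * mlp_ratio * d_model * d_model  # up and down projections
    return attn + mlp


def attention_block_params(d_model: int, n_invocations: int, shared: bool) -> int:
    """Unique parameters spent on transformer blocks across the network."""
    per_block = transformer_block_params(d_model)
    return per_block if shared else per_block * n_invocations


def block_matmul_flops_per_token(d_model: int, n_invocations: int) -> int:
    """About 2 FLOPs per weight per token per invocation. Sharing the weights
    does not reduce this: each invocation still runs the full block."""
    return 2 * transformer_block_params(d_model) * n_invocations


if __name__ == "__main__":
    d_model, n_invocations = 2048, 6  # hypothetical sizes
    print("unshared params:", attention_block_params(d_model, n_invocations, shared=False))
    print("shared params:  ", attention_block_params(d_model, n_invocations, shared=True))
    print("FLOPs/token (either way):", block_matmul_flops_per_token(d_model, n_invocations))
```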
What the success of this architecture implies is that even when attention is used rarely, there is still great redundancy in the attention parameters: the vast majority of them are not needed. While sequence mixing via full MHA is necessary at regular intervals, the attention block itself does not need separate parameters at each occurrence. We conjecture that attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways, and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are, whereas otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a context different from the one in which they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
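The concatenation step can be sketched at the level of tensor shapes. In the NumPy example below, a single nonlinearity stands in for the real attention + MLP block, and the projection back to the model width is an assumption about how the doubled width is reconciled with the residual stream; the function and argument names are hypothetical and this is not the Zamba code.

```python
import numpy as np


def shared_block_with_concat(hidden: np.ndarray, original_emb: np.ndarray,
                             w_in: np.ndarray, w_out: np.ndarray) -> np.ndarray:
    """hidden, original_emb: (seq_len, d). w_in: (2d, 2d). w_out: (2d, d)."""
    # Concatenate the current residual stream with the original token embeddings.
    x = np.concatenate([hidden, original_emb], axis=-1)  # (seq_len, 2d)
    # Stand-in for the shared attention + MLP block operating at width 2d.
    y = np.tanh(x @ w_in)                                # (seq_len, 2d)
    # Project back to width d and add to the residual stream.
    return hidden + y @ w_out                            # (seq_len, d)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seq_len, d = 5, 8
    hidden = rng.normal(size=(seq_len, d))
    original_emb = rng.normal(size=(seq_len, d))
    w_in = rng.normal(scale=0.1, size=(2 * d, 2 * d))
    w_out = rng.normal(scale=0.1, size=(2 * d, d))
    print(shared_block_with_concat(hidden, original_emb, w_in, w_out).shape)
```

Because the block sees the original embeddings directly on every invocation, it does not have to rely on the residual stream having preserved them.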
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
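To illustrate the idea, the NumPy sketch below shows one shared weight matrix specialized by a separate low-rank adapter for each invocation. It is a minimal sketch under assumptions: the class name, the initialization scales, and the convention of zero-initializing the second low-rank factor (so every invocation starts out identical to the shared weight) are illustrative choices and are not taken from the Zamba2 code.

```python
import numpy as np


class SharedLinearWithLoRA:
    """One shared (d_in, d_out) weight plus a rank-`rank` adapter per invocation."""

    def __init__(self, d_in: int, d_out: int, n_invocations: int, rank: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.weight = rng.normal(0.0, 0.02, size=(d_in, d_out))  # shared by all invocations
        self.lora_a = [rng.normal(0.0, 0.02, size=(d_in, rank)) for _ in range(n_invocations)]
        # Zero init: the adapter contributes nothing until it is trained.
        self.lora_b = [np.zeros((rank, d_out)) for _ in range(n_invocations)]

    def __call__(self, x: np.ndarray, invocation: int) -> np.ndarray:
        # Shared projection plus the low-rank correction for this invocation.
        low_rank = (x @ self.lora_a[invocation]) @ self.lora_b[invocation]
        return x @ self.weight + low_rank

    def param_counts(self) -> tuple[int, int]:
        """Return (shared parameters, total adapter parameters)."""
        lora = sum(a.size + b.size for a, b in zip(self.lora_a, self.lora_b))
        return self.weight.size, lora


if __name__ == "__main__":
    layer = SharedLinearWithLoRA(d_in=2048, d_out=2048, n_invocations=6, rank=8)
    shared, lora = layer.param_counts()
    print(f"shared: {shared}, adapters: {lora}, overhead: {lora / shared:.2%}")
```

The adapter cost is `rank * (d_in + d_out)` per invocation, which stays small relative to the shared `d_in * d_out` weight when the rank is low.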
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that sends each token of the input sequence(s) to the appropriate MLP experts. While an MoE has the inference latency of only the parameters used in its forward pass, all parameters must still be loaded into VRAM, which for large models often means inference can only be performed distributed across GPU clusters.
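The gap between resident and active parameters can be sketched with simple accounting for a single MoE MLP layer. The example assumes two-matrix (up and down projection) experts, a linear router, and top-k routing; the function name and the sizes in the usage example are hypothetical.

```python
def moe_mlp_params(d_model: int, d_ff: int, n_experts: int, top_k: int) -> tuple[int, int]:
    """Return (total, active) weight counts for one MoE MLP layer.

    total  : weights that must be resident in memory
    active : weights actually used for any single token
    """
    expert = 2 * d_model * d_ff  # up and down projections of one expert
    router = d_model * n_experts # linear router producing one logit per expert
    total = n_experts * expert + router
    active = top_k * expert + router
    return total, active


if __name__ == "__main__":
    total, active = moe_mlp_params(d_model=4096, d_ff=14336, n_experts=8, top_k=2)
    print(f"resident: {total / 1e6:.0f}M weights, active per token: {active / 1e6:.0f}M weights")
```

Latency tracks the active count, while the memory footprint tracks the total count, which is the mismatch that makes MoEs awkward for memory-constrained devices.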
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, along with a detailed walkthrough of the calculations performed in these scripts.
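As a rough idea of what such a parameter calculation involves, the sketch below counts the weights of a Mamba1-style model. It is not the cookbook's script: it assumes the layer layout of the reference Mamba1 block (expansion factor 2, state size 16, convolution width 4, `dt_rank = ceil(d_model / 16)`, biases only on the convolution and the dt projection, one norm per block, tied input and output embeddings), and the function names are hypothetical.

```python
import math


def mamba1_block_params(d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2) -> int:
    """Weight count of one Mamba1-style block, including its pre-norm."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner             # projects to x and gate branches
    conv1d = d_inner * d_conv + d_inner         # depthwise conv weights + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)  # produces dt, B, and C
    dt_proj = dt_rank * d_inner + d_inner       # weights + bias
    a_log = d_inner * d_state                   # state matrix A (stored as log)
    d_skip = d_inner                            # skip connection D
    out_proj = d_inner * d_model
    norm = d_model
    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj + norm


def mamba1_model_params(d_model: int, n_layers: int, vocab_size: int, **block_kwargs) -> int:
    """Blocks + tied embedding matrix + final norm."""
    block = mamba1_block_params(d_model, **block_kwargs)
    return n_layers * block + vocab_size * d_model + d_model


if __name__ == "__main__":
    # Hypothetical small configuration.
    total = mamba1_model_params(d_model=768, n_layers=24, vocab_size=50280)
    print(f"{total / 1e6:.1f}M parameters")
```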
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
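The core of such a token count is small. The sketch below is not the script referred to above: it takes any iterable of strings and any tokenizer callable, and uses a toy whitespace tokenizer so that it runs without downloads. With Hugging Face libraries, one would presumably pass the text column of a dataset loaded with `datasets.load_dataset` and a callable such as `lambda t: tokenizer(t)["input_ids"]`, where `tokenizer` comes from `transformers.AutoTokenizer.from_pretrained`; the function names here are hypothetical.

```python
from typing import Callable, Iterable, List


def count_tokens(texts: Iterable[str], tokenize: Callable[[str], List[int]]) -> int:
    """Total number of tokens produced by `tokenize` over all texts."""
    return sum(len(tokenize(text)) for text in texts)


def whitespace_tokenize(text: str) -> List[int]:
    """Toy stand-in tokenizer: one placeholder id per whitespace-separated word."""
    return [len(word) for word in text.split()]


if __name__ == "__main__":
    corpus = ["hybrid models mix attention and Mamba", "token counts drive the budget"]
    print("total tokens:", count_tokens(corpus, whitespace_tokenize))
```

Streaming over the texts rather than materializing all token ids keeps memory use flat, which matters once the dataset no longer fits in RAM.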
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
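A rough sense of why this is cheap: a rank-`r` LoRA on a `d_in x d_out` projection adds `r * (d_in + d_out)` parameters. The sketch below compares one LoRA set per invocation of the shared block against giving each invocation its own full copy of the adapted projections. The sizes, the rank, and the number of adapted projections are illustrative assumptions, not Zamba2's actual settings.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA pair: A is (d_in x rank), B is (rank x d_out)."""
    return rank * (d_in + d_out)


# Illustrative values (assumptions), not a released configuration.
d_model, rank = 2048, 8
n_adapted = 4  # hypothetical number of square d_model x d_model projections adapted

per_invocation_lora = n_adapted * lora_params(d_model, d_model, rank)
per_invocation_full_copy = n_adapted * d_model * d_model

print(f"LoRA per invocation:      {per_invocation_lora / 1e6:.2f}M parameters")
print(f"full copy per invocation: {per_invocation_full_copy / 1e6:.2f}M parameters")
```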
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that sends each token of the input sequence(s) to the appropriate MLP experts. While an MoE has the inference latency of only its forward-pass (active) parameters, all parameters need to be loaded into VRAM, which often means that for large models inference can only be performed distributed across GPU clusters.
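The gap between stored and active parameters can be sketched with simple arithmetic. The example assumes a plain two-matrix expert MLP, a linear router, and top-k routing; the sizes and names are illustrative assumptions only.

```python
def moe_mlp_params(d_model: int, d_ff: int, n_experts: int, top_k: int):
    """Return (total, active) parameter counts for one MoE MLP layer.

    total  -> must be held in memory
    active -> used per token in the forward pass (drives per-token FLOPs)
    """
    expert = 2 * d_model * d_ff        # up and down projections of one expert
    router = d_model * n_experts       # linear router
    total = n_experts * expert + router
    active = top_k * expert + router
    return total, active


# Illustrative sizes (assumptions).
total, active = moe_mlp_params(d_model=1024, d_ff=4096, n_experts=8, top_k=2)
print(f"stored per layer: {total / 1e6:.1f}M parameters")
print(f"active per token: {active / 1e6:.1f}M parameters")
print(f"bf16 memory per layer: {2 * total / 2**20:.1f} MiB")
```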
State Space Models (SSMs) offer a more efficient alternative to traditional attention mechanisms, and are particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, as well as a detailed walkthrough of the calculations performed in these scripts.
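As a flavor of what such a calculation looks like, here is a minimal parameter-count sketch for a Mamba1 model. It is not the cookbook's script: the layer shapes follow the reference Mamba block as we understand it (input projection, depthwise convolution, `x_proj`, `dt_proj`, `A_log`, `D`, output projection), and it assumes one RMSNorm per block, a final norm, and tied input/output embeddings. The default hyperparameters and the example sizes are assumptions for illustration.

```python
import math


def mamba1_block_params(d_model: int, d_state: int = 16, d_conv: int = 4,
                        expand: int = 2) -> int:
    """Parameters in one Mamba1 mixer (norm excluded)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner                 # x and gate branches, no bias
    conv1d = d_inner * d_conv + d_inner             # depthwise conv weights + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)      # produces dt, B and C
    dt_proj = dt_rank * d_inner + d_inner           # weights + bias
    a_log = d_inner * d_state
    d_skip = d_inner
    out_proj = d_inner * d_model                    # no bias
    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj


def mamba1_model_params(d_model: int, n_layers: int, vocab_size: int) -> int:
    """Total parameters, assuming tied embeddings and RMSNorm weights only."""
    per_layer = mamba1_block_params(d_model) + d_model   # mixer + block norm
    return vocab_size * d_model + n_layers * per_layer + d_model  # + final norm


# Illustrative sizing (assumption).
total = mamba1_model_params(d_model=768, n_layers=24, vocab_size=50280)
print(f"{total / 1e6:.1f}M parameters")
```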
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
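The core of such a token count is small. The sketch below is not the cookbook's script; it is a dependency-free illustration in which the tokenizer is passed in as a callable. Whitespace splitting is used here purely as a stand-in. With a Hugging Face tokenizer, one would pass the tokenizer's `encode` method instead, and iterate over the text column of the dataset.

```python
from typing import Callable, Iterable, Sequence


def count_tokens(texts: Iterable[str], encode: Callable[[str], Sequence]) -> int:
    """Sum the number of tokens produced by `encode` over all documents."""
    return sum(len(encode(text)) for text in texts)


# Stand-in data and tokenizer (assumptions for illustration).
docs = [
    "hybrid models mix attention and Mamba",
    "tokens are counted per document",
]
total = count_tokens(docs, str.split)
print(f"total tokens: {total}")
```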
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need ZeRO-1 to shard optimizer states. For larger models such as Zamba-7B, we used tensor parallelism (TP), for which we created our own custom implementations for both Mamba and Mamba2. We also used expert parallelism (EP) for training BlackMamba.
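To see why sharding only the optimizer states already helps, consider the usual mixed-precision Adam accounting: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for the fp32 master weights plus the two Adam moments. The sketch below applies that accounting with and without ZeRO-1. It is an estimate under these assumptions, it ignores activations and buffers, and the model size and data-parallel degree are illustrative.

```python
def model_state_bytes_per_gpu(n_params: int, dp_size: int, zero1: bool) -> float:
    """Weights + gradients + optimizer state held on each data-parallel rank."""
    weights = 2 * n_params       # 16-bit weights
    grads = 2 * n_params         # 16-bit gradients
    optimizer = 12 * n_params    # fp32 master weights, Adam first and second moments
    if zero1:
        optimizer = optimizer / dp_size   # ZeRO-1 shards only the optimizer state
    return weights + grads + optimizer


# Illustrative setup (assumption): 1B parameters on 8 data-parallel ranks.
n_params, dp_size = 1_000_000_000, 8
for zero1 in (False, True):
    gib = model_state_bytes_per_gpu(n_params, dp_size, zero1) / 2**30
    print(f"ZeRO-1={zero1}: {gib:.1f} GiB of model state per GPU")
```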
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, and RetNet perform well at short-context language modelling while lacking in long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit in some aspects of performance, they are significantly more FLOP- and memory-efficient than attention layers.
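The efficiency gap comes from how costs scale with sequence length. As a rough sketch, the sequence-mixing part of full attention (the `QK^T` product and the attention-weighted sum of values) costs about `4 * L^2 * d` FLOPs per layer, and the KV cache grows linearly with `L`, whereas an SSM layer carries a fixed-size recurrent state. The formulas below are approximations that ignore projections, causal masking savings, and the convolution state; the sizes are illustrative assumptions.

```python
def attention_mixing_flops(seq_len: int, d_model: int) -> int:
    """QK^T plus attention-weighted sum of V, summed over heads (non-causal)."""
    return 4 * seq_len**2 * d_model


def kv_cache_bytes(seq_len: int, d_model: int, bytes_per_value: int = 2) -> int:
    """Keys and values cached for every past token in one full-MHA layer."""
    return 2 * seq_len * d_model * bytes_per_value


def ssm_state_bytes(d_inner: int, d_state: int, bytes_per_value: int = 2) -> int:
    """Recurrent state of one SSM layer; independent of sequence length."""
    return d_inner * d_state * bytes_per_value


d_model = 2048  # illustrative (assumption)
for seq_len in (2_048, 16_384, 131_072):
    print(
        f"L={seq_len:>7}: attention mixing {attention_mixing_flops(seq_len, d_model):.2e} FLOPs, "
        f"KV cache {kv_cache_bytes(seq_len, d_model) / 2**20:.0f} MiB, "
        f"SSM state {ssm_state_bytes(2 * d_model, 16) / 2**10:.0f} KiB"
    )
```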
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, full attention is not needed at every sequence mixer; substantially less attention suffices, which is what enables hybrids to work. Similar findings have recently been popularized for transformers, with some recent models, such as those used by CharacterAI, claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of using cheap local sequence mixers with occasional full attention, but is less efficient than a Mamba hybrid since sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason that so little full attention is needed relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations (see, for example, the surprising effectiveness of pure N-gram models). Occasionally, however, long-term information retrieval or other in-context learning is required, which a smaller number of attention layers can handle. In our experiments, we observe that only between 1/6 and 1/4 of the sequence mixer layers should be full attention, a phenomenon also reported here.
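A layer layout with this ratio is easy to write down. The sketch below places one full-attention mixer after every `attention_every - 1` Mamba mixers; the even spacing and the function name are assumptions for illustration, since the text only fixes the overall ratio.

```python
from typing import List


def hybrid_layout(n_mixers: int, attention_every: int) -> List[str]:
    """Return the sequence-mixer type for each layer of a hybrid stack."""
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_mixers)
    ]


layout = hybrid_layout(n_mixers=12, attention_every=6)
print(layout)
print(f"attention fraction: {layout.count('attention') / len(layout):.3f}")
```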
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We create calculation scripts for the parameters and FLOPs of mamba models here as well as a detailed walkthrough of the calculations performed in these scripts.
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script at here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need to utilize Zero-1 to shard optimizer states. For larger models such as Zamba-7B, we utilized tensor-parallelism (TP) for which we created our own custom implementation in both Mamba and Mamba2. We also utilized expert-parallelism (EP) for training BlackMamba.
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We create calculation scripts for the parameters and FLOPs of mamba models here as well as a detailed walkthrough of the calculations performed in these scripts.
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script at here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, as well as a detailed walkthrough of the calculations performed in these scripts.
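As a rough idea of what such a calculation looks like, the sketch below counts the parameters of a single Mamba1 block. It is not the cookbook's script: the layout (input projection, depthwise convolution, `x_proj`, `dt_proj`, `A_log`, `D`, output projection) and the defaults (`expand=2`, `d_state=16`, `d_conv=4`, `dt_rank=ceil(d_model / 16)`) are assumptions based on our reading of the reference Mamba implementation, and the normalization layer is left out.

```python
import math


def mamba1_block_params(d_model: int, expand: int = 2, d_state: int = 16,
                        d_conv: int = 4) -> dict:
    """Per-component parameter counts for one Mamba1 block (norm excluded)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    return {
        "in_proj": d_model * 2 * d_inner,               # produces x and gate z
        "conv1d": d_inner * d_conv + d_inner,           # depthwise weights + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),    # produces dt, B, C
        "dt_proj": dt_rank * d_inner + d_inner,         # weights + bias
        "A_log": d_inner * d_state,
        "D": d_inner,
        "out_proj": d_inner * d_model,
    }


counts = mamba1_block_params(d_model=768)
for name, n in counts.items():
    print(f"{name:>8}: {n:,}")
print(f"   total: {sum(counts.values()):,}")
```

The two large projections dominate, so the block costs roughly `3 * expand * d_model**2` parameters.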
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of Megatron-LM. For smaller models we only need ZeRO-1 to shard optimizer states. For larger models such as Zamba-7B, we used tensor parallelism (TP), for which we created our own custom implementations for both Mamba and Mamba2. We also used expert parallelism (EP) for training BlackMamba.
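A back-of-the-envelope estimate shows why sharding only the optimizer states is often enough for small models. The sketch assumes the usual mixed-precision Adam accounting (2 bytes of weights, 2 bytes of gradients, and 12 bytes of fp32 optimizer state per parameter) and ignores activations and buffers; the function names are hypothetical.

```python
def bytes_per_param(dp_degree: int, zero1: bool = True) -> float:
    """Model-state bytes per parameter on each data-parallel rank."""
    weights, grads = 2, 2   # 16-bit weights and gradients, replicated
    optimizer = 12          # fp32 master weights + Adam momentum + variance
    if zero1:
        optimizer /= dp_degree  # ZeRO-1 shards only the optimizer states
    return weights + grads + optimizer


def model_state_gib(n_params: float, dp_degree: int, zero1: bool = True) -> float:
    """Model-state memory per rank in GiB (activations not included)."""
    return n_params * bytes_per_param(dp_degree, zero1) / 2**30


for dp in (1, 8, 64):
    print(f"dp={dp:>2}: {model_state_gib(1.2e9, dp):.1f} GiB per rank")
```

As the data-parallel degree grows, the per-rank cost approaches 4 bytes per parameter; past that point, further savings need TP or higher ZeRO stages.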
We find, following miniCPM, that a simple curriculum training approach of increasing the proportion of higher quality tokens towards the end of training can significantly improve performance.
'High quality' is obviously subjective in part, but we find documents containing fact-rich information to be the most performant. Examples include:
In terms of the amount of annealing data, we find in general that more is better, although we are generally constrained by the amount of available annealing data, so we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM finding that annealing should cover about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, but epochs beyond the second give little further improvement.
Model Tech Reports Using Annealing
Papers on Annealing/LR
We performed significant ablations to explore the LR schedule. We made the following observations:
When doing annealing we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution.
We typically find that a mix of 50-70% 'replay' tokens from the original pre-training dataset and 30-50% tokens from the annealing datasets is optimal. Within this range, sensitivity to the exact replay fraction is quite low, but our intuition is that the replay fraction should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible.
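The numbers above can be turned into a quick token budget. This is a minimal sketch that assumes the roughly 10% annealing budget covers the whole annealing phase (replay included); the function name, defaults, and example sizes are illustrative.

```python
def annealing_plan(total_tokens: float, anneal_dataset_tokens: float,
                   anneal_fraction: float = 0.10,
                   replay_fraction: float = 0.6) -> dict:
    """Split the annealing phase into replay and annealing tokens."""
    phase_tokens = total_tokens * anneal_fraction
    replay_tokens = phase_tokens * replay_fraction
    anneal_tokens = phase_tokens - replay_tokens
    return {
        "phase_tokens": phase_tokens,
        "replay_tokens": replay_tokens,
        "anneal_tokens": anneal_tokens,
        # How many passes over the annealing set this budget implies.
        "anneal_epochs": anneal_tokens / anneal_dataset_tokens,
    }


plan = annealing_plan(total_tokens=3e12, anneal_dataset_tokens=60e9)
for key, value in plan.items():
    print(f"{key}: {value:.3g}")
```

With these example values the annealing set is seen about twice, which is around the point where we see additional epochs stop helping.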
Concretely, our recommendations for annealing are:
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, and RetNet perform well at short-context language modelling while lagging at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit in some aspects of performance, they are significantly more FLOP- and memory-efficient than attention layers.
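The memory side of this efficiency gap is easy to see at inference time: an attention layer's KV cache grows with sequence length, while a Mamba-style layer keeps a fixed-size state. The sketch below assumes 16-bit storage and a per-layer Mamba state made of the recurrent state plus a small convolution buffer; the shapes and example sizes are illustrative assumptions.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Keys and values cached for every past token in every attention layer."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem


def ssm_state_bytes(n_layers: int, d_inner: int, d_state: int,
                    d_conv: int = 4, bytes_per_elem: int = 2) -> int:
    """Fixed-size recurrent state plus convolution buffer, independent of length."""
    return n_layers * d_inner * (d_state + d_conv) * bytes_per_elem


for seq_len in (4096, 32768):
    kv = kv_cache_bytes(32, 32, 128, seq_len)
    ssm = ssm_state_bytes(32, 8192, 16)
    print(f"seq_len={seq_len:>6}: KV cache {kv / 2**20:,.0f} MiB, "
          f"SSM state {ssm / 2**20:,.0f} MiB")
```

Both figures are per sequence, so the gap widens further with batch size.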
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, full attention is not needed at every sequence mixer; substantially less attention suffices, which is what enables hybrids to work. Similar findings have recently been popularized for transformers, with some recent models, such as those used by CharacterAI, claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of using cheap local sequence mixers with occasional full attention, but is less efficient than a Mamba hybrid since sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason that so little full attention is needed relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations -- see, for example, the surprising effectiveness of pure N-gram models. Occasionally, however, long-term information retrieval or other in-context learning is required, which a smaller number of attention layers can handle. In our experiments, we observe that only between 1/4 and 1/6 of the sequence mixer layers should be full attention, a phenomenon also reported here.
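A layer schedule with this kind of ratio can be sketched in a few lines. Placing the attention layer at the end of each group is an arbitrary choice made for illustration, and the function name is hypothetical.

```python
def hybrid_layer_pattern(n_layers: int, attention_every: int = 6) -> list:
    """Return a list of mixer types with one attention layer per group."""
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_layers)
    ]


pattern = hybrid_layer_pattern(n_layers=24, attention_every=6)
print(pattern)
print(f"attention share: {pattern.count('attention') / len(pattern):.3f}")
```

Setting `attention_every=4` gives the denser end of the range discussed above.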
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block, consisting of an attention block and an MLP block, is re-used multiple times throughout the network. This shared block is the only attention in the network. Sharing increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
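The parameter saving from sharing can be estimated with a simple count. The sketch assumes a textbook transformer block (four `d_model x d_model` attention projections plus a two-matrix MLP with 4x expansion, no biases or norms), which is a simplification and not the exact Zamba block; the names and sizes are illustrative.

```python
def transformer_block_params(d_model: int, mlp_expansion: int = 4) -> int:
    """Attention (Q, K, V, O projections) plus a two-matrix MLP."""
    attention = 4 * d_model ** 2
    mlp = 2 * mlp_expansion * d_model ** 2
    return attention + mlp


def attention_block_budget(d_model: int, n_invocations: int, shared: bool) -> int:
    """Parameters spent on transformer blocks across the whole network."""
    n_copies = 1 if shared else n_invocations
    return n_copies * transformer_block_params(d_model)


d_model, n_invocations = 2048, 6
shared = attention_block_budget(d_model, n_invocations, shared=True)
unshared = attention_block_budget(d_model, n_invocations, shared=False)
print(f"shared: {shared:,} params, unshared: {unshared:,} params")
```

The FLOPs are the same in both cases, since the block is still invoked `n_invocations` times; only the parameter count changes.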
What the success of this architecture implies is that even when attention is used rarely, there is still great redundancy in the attention parameters -- namely, the vast majority of them are not needed. While sequence mixing via full MHA is needed regularly, the attention blocks themselves do not need separate parameters. We conjecture that attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways, and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are, whereas otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a context different from the one in which they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
If you found this repository helpful, please consider citing it using:
For context, we at Zyphra have built the following hybrid models:
The following datasets:
And the following engineering optimizations:
For context, we at Zyphra have built the following hybrid models:
The following datasets:
And the following engineering optimizations
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We create calculation scripts for the parameters and FLOPs of mamba models here as well as a detailed walkthrough of the calculations performed in these scripts.
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script at here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need to utilize Zero-1 to shard optimizer states. For larger models such as Zamba-7B, we utilized tensor-parallelism (TP) for which we created our own custom implementation in both Mamba and Mamba2. We also utilized expert-parallelism (EP) for training BlackMamba.
We find, following miniCPM, that a simple curriculum training approach of increasing the proportion of higher quality tokens towards the end of training can significantly improve performance.
'High quality' is obviously subjective in part but we find that documents containing fact-rich information to be the most performant. Examples include:
In terms of the amount of annealing data, we find in general that more is better, although we are generally constrained by amount of available annealing data so that we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM findings of setting annealing to be about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, yet beyond 2 epochs give little performance improvement.
Model Tech Reports Using Annealing
Papers on Annealing/LR
We performed significant ablations to explore the LR schedule. We made the following observations:
When doing annealing we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution.
We typically find that a fraction of 50-70% 'replay' tokens from the original pre-training dataset and 50-30% tokens from the annealing datasets is optimal. Within this range, we find that the sensitivity to the exact replay fraction is quite low, yet we hold the intuition that replay should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible.
Concretely, our reccomendations for annealing are:
If you found this repository helpful, please consider citing it using:
For context, we at Zyphra have built the following hybrid models:
The following datasets:
And the following engineering optimizations
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block, consisting of an attention block and an MLP block, is re-used multiple times throughout the network. This shared block is the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
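To make the parameter saving from sharing concrete, here is a minimal back-of-the-envelope sketch. It assumes a plain attention block with four square projections and a non-gated two-matrix MLP, with no biases or norms. These shapes and the example sizes are illustrative assumptions, not the actual Zamba configuration.

```python
def transformer_block_params(d_model: int, d_ff: int) -> int:
    """Approximate weight count of one attention + MLP block (no biases/norms)."""
    attention = 4 * d_model * d_model  # Q, K, V and output projections
    mlp = 2 * d_model * d_ff           # up-projection and down-projection
    return attention + mlp


def attention_path_params(d_model: int, d_ff: int,
                          n_invocations: int, shared: bool) -> int:
    """Parameters spent on transformer blocks across the whole network.

    With sharing, the same weights are reused at every invocation, so the
    cost is paid once; without sharing it is paid once per invocation.
    """
    per_block = transformer_block_params(d_model, d_ff)
    return per_block if shared else per_block * n_invocations


# Hypothetical sizes, chosen only for illustration.
unshared = attention_path_params(2048, 8192, n_invocations=6, shared=False)
shared = attention_path_params(2048, 8192, n_invocations=6, shared=True)
print(f"unshared: {unshared:,}  shared: {shared:,}")
```

The FLOP cost is the same in both cases, since the block is still executed at every invocation; only the parameter count changes.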
The success of this architecture implies that even when attention is used rarely, there is still great redundancy in the attention parameters -- namely, the vast majority of them are not needed. While sequence mixing via full MHA is necessary at regular intervals, the attention block itself does not need separate parameters at each position. We conjecture that attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways, and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that this effectively 'reminds' the network continually of what the input tokens are, whereas otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a context different from the one in which they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
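The parameter cost of such adapters is easy to estimate. A rank-r LoRA on a weight matrix of shape d_in x d_out adds r * (d_in + d_out) parameters. The sketch below assumes one independent adapter per adapted matrix per invocation of the shared block. Which matrices are adapted, the rank, and all sizes are hypothetical choices for illustration.

```python
from typing import Iterable, Tuple


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters of one LoRA adapter: A is (d_in x rank), B is (rank x d_out)."""
    return rank * (d_in + d_out)


def shared_block_with_lora_params(shared_params: int,
                                  adapted_shapes: Iterable[Tuple[int, int]],
                                  rank: int,
                                  n_invocations: int) -> int:
    """Shared weights are counted once; adapters are counted per invocation."""
    per_invocation = sum(lora_params(d_in, d_out, rank)
                         for d_in, d_out in adapted_shapes)
    return shared_params + n_invocations * per_invocation


# Hypothetical example: a 50M-parameter shared block, invoked 6 times,
# with rank-128 adapters on two 2048 x 8192 matrices.
total = shared_block_with_lora_params(
    shared_params=50_000_000,
    adapted_shapes=[(2048, 8192), (2048, 8192)],
    rank=128,
    n_invocations=6,
)
print(f"total with LoRA: {total:,}")
```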
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that dispatches each token of the input sequence(s) to the appropriate MLP experts. While an MoE has the inference latency of only the parameters active in its forward pass, all parameters must be loaded into VRAM, which for large models often means inference can only be performed distributed across GPU clusters.
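The gap between what an MoE layer must hold in memory and what it uses per token can be sketched as follows. The layer is assumed to have a linear router, non-gated two-matrix experts, and top-k routing. All names and sizes are illustrative assumptions.

```python
from typing import Tuple


def moe_mlp_params(d_model: int, d_ff: int,
                   n_experts: int, top_k: int) -> Tuple[int, int]:
    """Return (total, active) weight counts for one MoE MLP layer.

    total  -> must be resident in memory
    active -> actually used for a single token's forward pass
    """
    expert = 2 * d_model * d_ff    # up- and down-projection of one expert
    router = d_model * n_experts   # linear router producing one logit per expert
    total = n_experts * expert + router
    active = top_k * expert + router
    return total, active


# Hypothetical layer: 8 experts, 2 active per token.
total, active = moe_mlp_params(d_model=1024, d_ff=4096, n_experts=8, top_k=2)
print(f"total: {total:,}  active per token: {active:,}")
```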
State Space Models (SSMs) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
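One way to see the memory advantage at inference time is to compare per-sequence cache sizes. An attention layer stores keys and values for every past token, while a Mamba-style layer keeps a fixed-size recurrent state. The sketch below assumes a Mamba1-style cache (an SSM state of size d_inner x d_state plus a convolution buffer of size d_inner x d_conv) and 2-byte elements. The function names and example sizes are assumptions.

```python
def kv_cache_bytes(n_attn_layers: int, seq_len: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Keys + values for every cached token, for a single sequence."""
    return 2 * n_attn_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem


def mamba_state_bytes(n_mamba_layers: int, d_inner: int, d_state: int,
                      d_conv: int, bytes_per_elem: int = 2) -> int:
    """Fixed-size recurrent state for a single sequence (independent of length)."""
    ssm_state = d_inner * d_state   # hidden state of the selective SSM
    conv_state = d_inner * d_conv   # rolling buffer for the short convolution
    return n_mamba_layers * (ssm_state + conv_state) * bytes_per_elem


# Hypothetical 24-layer models at two context lengths.
for seq_len in (4_096, 65_536):
    kv = kv_cache_bytes(24, seq_len, n_kv_heads=16, head_dim=128)
    ssm = mamba_state_bytes(24, d_inner=4096, d_state=16, d_conv=4)
    print(f"seq_len={seq_len}: KV cache {kv / 2**20:.1f} MiB, "
          f"Mamba state {ssm / 2**20:.1f} MiB")
```

The KV cache grows linearly with context length, while the Mamba state stays constant. In a hybrid, only the few attention invocations contribute to the growing term.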
However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, as well as a detailed walkthrough of the calculations performed in these scripts.
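For orientation, a parameter count for a Mamba1-style model can be sketched in a few lines. This is not the cookbook's script. It assumes the layer layout of the reference Mamba1 block (bias-free input and output projections, a depthwise convolution with bias, `dt_rank = ceil(d_model / 16)`, one norm weight per block) and tied input/output embeddings by default. Function names and defaults are assumptions.

```python
import math


def mamba1_block_params(d_model: int, d_state: int = 16,
                        d_conv: int = 4, expand: int = 2) -> int:
    """Weights in one Mamba1-style block, including its pre-norm."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner             # produces x and the gate z
    conv1d = d_inner * d_conv + d_inner         # depthwise conv weights + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)  # produces dt, B and C
    dt_proj = dt_rank * d_inner + d_inner       # weights + bias
    a_log = d_inner * d_state                   # state matrix A (stored as log)
    d_skip = d_inner                            # skip-connection vector D
    out_proj = d_inner * d_model
    norm = d_model
    return (in_proj + conv1d + x_proj + dt_proj
            + a_log + d_skip + out_proj + norm)


def mamba1_model_params(d_model: int, n_layers: int, vocab_size: int,
                        tie_embeddings: bool = True, **block_kwargs: int) -> int:
    """Embedding + blocks + final norm (+ separate LM head if untied)."""
    embedding = vocab_size * d_model
    lm_head = 0 if tie_embeddings else embedding
    final_norm = d_model
    blocks = n_layers * mamba1_block_params(d_model, **block_kwargs)
    return embedding + blocks + final_norm + lm_head


# Hypothetical small configuration.
print(f"{mamba1_model_params(d_model=768, n_layers=24, vocab_size=50280):,}")
```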
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
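The core of such a token count is small. The following is a minimal sketch, not the provided script. It assumes the `datasets` and `transformers` packages are installed, that the dataset has a text column named `text`, and that streaming the split is acceptable. The function names and defaults are hypothetical.

```python
from typing import Callable, Iterable, Sequence


def count_tokens(texts: Iterable[str],
                 tokenize: Callable[[str], Sequence]) -> int:
    """Sum the number of tokens produced by `tokenize` over all documents."""
    return sum(len(tokenize(text)) for text in texts)


def count_hf_dataset_tokens(dataset_name: str, tokenizer_name: str,
                            split: str = "train",
                            text_column: str = "text") -> int:
    """Count tokens in one split of a Hugging Face dataset (streamed)."""
    # Imported lazily so count_tokens() works without these packages.
    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    dataset = load_dataset(dataset_name, split=split, streaming=True)
    return count_tokens(
        (row[text_column] for row in dataset),
        # Note: depending on the tokenizer, this may include special tokens.
        lambda text: tokenizer(text)["input_ids"],
    )


# Dependency-free usage with a whitespace "tokenizer" as a stand-in.
print(count_tokens(["a b c", "d e"], lambda s: s.split()))
```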
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in JAX here. Why JAX when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of JAX comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of Megatron-LM. For smaller models we only need ZeRO-1 to shard optimizer states. For larger models such as Zamba-7B, we used tensor parallelism (TP), for which we created our own custom implementation for both Mamba and Mamba2. We also used expert parallelism (EP) for training BlackMamba.
We find, following miniCPM, that a simple curriculum training approach of increasing the proportion of higher quality tokens towards the end of training can significantly improve performance.
'High quality' is obviously subjective in part, but we find documents containing fact-rich information to be the most performant. Examples include:
In terms of the amount of annealing data, we find in general that more is better, although we are generally constrained by the amount of available annealing data, so we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM finding of setting annealing to about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, but more than 2 epochs yield little additional improvement.
Model Tech Reports Using Annealing
Papers on Annealing/LR
We performed significant ablations to explore the LR schedule. We made the following observations:
When doing annealing we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution.
We typically find that a mix of 50-70% 'replay' tokens from the original pre-training dataset and 30-50% tokens from the annealing datasets is optimal. Within this range, sensitivity to the exact replay fraction is quite low, though our intuition is that replay should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible.
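The token budget implied by a given replay fraction follows from simple arithmetic. If annealing data makes up a fraction (1 - replay) of the phase, the phase length is the annealing token count divided by (1 - replay). The sketch below works through this. The function name and the example numbers are illustrative assumptions, not a recommended configuration.

```python
def annealing_mix(anneal_unique_tokens: float, epochs: float,
                  replay_fraction: float) -> dict:
    """Token budget of an annealing phase for a given replay fraction.

    anneal_unique_tokens: size of the annealing dataset in tokens
    epochs:               passes over the annealing dataset
    replay_fraction:      share of phase tokens drawn from pre-training data
    """
    if not 0.0 <= replay_fraction < 1.0:
        raise ValueError("replay_fraction must be in [0, 1)")
    anneal_tokens = anneal_unique_tokens * epochs
    phase_tokens = anneal_tokens / (1.0 - replay_fraction)
    return {
        "anneal_tokens": anneal_tokens,
        "replay_tokens": phase_tokens - anneal_tokens,
        "phase_tokens": phase_tokens,
    }


# Hypothetical example: 50B unique annealing tokens, 2 epochs, 60% replay.
mix = annealing_mix(50e9, epochs=2, replay_fraction=0.6)
for name, tokens in mix.items():
    print(f"{name}: {tokens / 1e9:.0f}B")
```

This also shows how a higher replay fraction lengthens the annealing phase for a fixed amount of annealing data.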
Concretely, our recommendations for annealing are:
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We create calculation scripts for the parameters and FLOPs of mamba models here as well as a detailed walkthrough of the calculations performed in these scripts.
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script at here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need to utilize Zero-1 to shard optimizer states. For larger models such as Zamba-7B, we utilized tensor-parallelism (TP) for which we created our own custom implementation in both Mamba and Mamba2. We also utilized expert-parallelism (EP) for training BlackMamba.
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
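As a rough illustration of why the LoRA cost is small, the sketch below counts the extra parameters when each invocation of a shared weight matrix gets its own low-rank adapter. Which matrices are adapted, the rank, and the number of invocations are all hypothetical values here, not the Zamba2 settings.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters of one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


def lora_overhead(d_in: int, d_out: int, rank: int,
                  n_adapted_matrices: int, n_invocations: int) -> float:
    """Adapter parameters relative to the shared matrices they specialize."""
    adapters = n_invocations * n_adapted_matrices * lora_params(d_in, d_out, rank)
    shared = n_adapted_matrices * d_in * d_out
    return adapters / shared


if __name__ == "__main__":
    # Hypothetical: rank-8 adapters on 4 square matrices, 6 invocations.
    print(f"{lora_overhead(1024, 1024, 8, 4, 6):.2%} extra parameters")
```

Under these assumed sizes, six sets of adapters add under a tenth of the parameters of the shared matrices, whereas six unshared copies would add five times their size.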
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that sends each token of the input sequence(s) to the appropriate MLP experts. While an MoE has the inference latency of only its forward-pass (active) parameters, all parameters must be loaded into VRAM, which for large models often means that inference can only be performed distributed across GPU clusters.
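The gap between what an MoE computes with and what it must hold in memory can be sketched with simple arithmetic. The function names, the top-k routing assumption, and the example sizes below are illustrative only.

```python
def moe_param_counts(shared_params: int, expert_params: int,
                     n_experts: int, top_k: int) -> tuple[int, int]:
    """Return (total, active) parameters for a top-k routed MoE."""
    total = shared_params + n_experts * expert_params
    active = shared_params + top_k * expert_params
    return total, active


def weight_memory_gib(n_params: int, bytes_per_param: int = 2) -> float:
    """Memory needed just to hold the weights (default: 16-bit weights)."""
    return n_params * bytes_per_param / 2**30


if __name__ == "__main__":
    # Hypothetical model: 2B shared parameters, 8 experts of 5B each, top-2 routing.
    total, active = moe_param_counts(2_000_000_000, 5_000_000_000, 8, 2)
    print(f"active: {active:,}  total: {total:,}")
    print(f"weights to load: {weight_memory_gib(total):.1f} GiB")
```

Latency tracks the active count, but the memory requirement tracks the total count, which is why large MoEs often need multiple GPUs even for inference.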
State Space Models (SSMs) offer a more efficient alternative to traditional attention mechanisms, which is particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, as well as a detailed walkthrough of the calculations performed in these scripts.
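For a sense of what such a calculation involves, here is a minimal sketch of a Mamba1 block parameter count. It assumes the layout of the reference Mamba implementation (`expand=2`, `d_state=16`, `d_conv=4`, `dt_rank = ceil(d_model / 16)`, no bias on the input, output and x projections) and is not the script from this repository. The training FLOP estimate is the generic 6 * N * D rule of thumb, which ignores the cost of the scan itself.

```python
import math


def mamba1_block_params(d_model: int, d_state: int = 16, d_conv: int = 4,
                        expand: int = 2) -> int:
    """Parameter count of one Mamba1 mixer, assuming the reference layout."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    in_proj = d_model * 2 * d_inner             # x and gate branches, no bias
    conv1d = d_inner * d_conv + d_inner         # depthwise conv weight + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)  # produces dt, B and C
    dt_proj = dt_rank * d_inner + d_inner       # weight + bias
    a_log = d_inner * d_state                   # state matrix A (stored as log)
    d_skip = d_inner                            # skip connection D
    out_proj = d_inner * d_model                # no bias
    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj


def approx_training_flops(n_params: int, n_tokens: int) -> int:
    """Rule-of-thumb 6 * N * D estimate (forward + backward matmuls only)."""
    return 6 * n_params * n_tokens


if __name__ == "__main__":
    per_block = mamba1_block_params(768)
    print(f"Mamba1 block at d_model=768: {per_block:,} parameters")
    print(f"24 blocks, 100B tokens: {approx_training_flops(24 * per_block, 100 * 10**9):.3e} FLOPs")
```

Norm layers, embeddings and the output head are left out; they would need to be added for a full-model count.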
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in JAX here. Why JAX when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of JAX for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of Megatron-LM. For smaller models we only need ZeRO-1 to shard optimizer states. For larger models such as Zamba-7B, we used tensor parallelism (TP), for which we created our own custom implementation for both Mamba and Mamba2. We also used expert parallelism (EP) for training BlackMamba.
We find, following miniCPM, that a simple curriculum training approach of increasing the proportion of higher quality tokens towards the end of training can significantly improve performance.
'High quality' is obviously subjective in part, but we find documents containing fact-rich information to be the most performant. Examples include:
In terms of the amount of annealing data, we find in general that more is better, although we are usually constrained by the amount of available annealing data, so we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM finding that annealing should be about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, but that epochs beyond the second give little performance improvement.
Model Tech Reports Using Annealing
Papers on Annealing/LR
We performed significant ablations to explore the LR schedule. We made the following observations:
When doing annealing we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution.
We typically find that a mix of 50-70% 'replay' tokens from the original pre-training dataset and 30-50% tokens from the annealing datasets is optimal. Within this range, sensitivity to the exact replay fraction is quite low, although our intuition is that the replay fraction should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible.
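The numbers above can be turned into a token budget with a few lines of arithmetic. This is only a sketch: it assumes the roughly 10% figure refers to the whole annealing phase (replay included), and the function name, defaults and example sizes are hypothetical.

```python
def annealing_budget(total_tokens: float, anneal_dataset_tokens: float,
                     anneal_fraction: float = 0.10,
                     replay_fraction: float = 0.60) -> dict:
    """Split an annealing phase into replay tokens and annealing-data tokens."""
    if not 0.0 <= replay_fraction < 1.0:
        raise ValueError("replay_fraction must be in [0, 1)")
    phase_tokens = total_tokens * anneal_fraction
    replay_tokens = phase_tokens * replay_fraction
    anneal_tokens = phase_tokens - replay_tokens
    return {
        "phase_tokens": phase_tokens,
        "replay_tokens": replay_tokens,
        "anneal_tokens": anneal_tokens,
        # Number of passes over the annealing dataset implied by this budget.
        "anneal_epochs": anneal_tokens / anneal_dataset_tokens,
    }


if __name__ == "__main__":
    # Hypothetical run: 1T total tokens, 20B tokens of annealing data.
    for name, value in annealing_budget(1e12, 2e10).items():
        print(f"{name}: {value:.3g}")
```

In this example the budget implies two epochs over the annealing data, which is about the point where the text above reports diminishing returns.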
Concretely, our recommendations for annealing are:
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is the sequence mixer. This is because attention is expensive at long sequence lengths, while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, and RetNet perform well at short-context language modelling while falling short at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention, while just enough full attention is used to maintain performance. Empirically, full attention is not needed at every sequence mixer; substantially less attention suffices, which is what enables hybrids to work. Similar findings have recently been popularized for transformers, with some recent models, such as those used by CharacterAI, claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of using cheap local sequence mixers with occasional full attention, but is less efficient than a Mamba hybrid since sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason that so little attention is needed relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations (see, for example, the surprising effectiveness of pure N-gram models). Occasionally, however, long-term information retrieval or other in-context learning is required, which a small number of attention layers can handle. In our experiments, we observe that only between 1/6 and 1/4 of the sequence mixer layers should be full attention, a phenomenon also reported here.
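A layer schedule with this kind of ratio can be sketched as below. Placing the attention layer at the end of each group of mixers is an arbitrary choice made for illustration, and the function name is hypothetical.

```python
def hybrid_schedule(n_mixers: int, attention_every: int = 6) -> list[str]:
    """Sequence-mixer types where one in every `attention_every` is attention."""
    if attention_every < 1:
        raise ValueError("attention_every must be >= 1")
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_mixers)
    ]


if __name__ == "__main__":
    schedule = hybrid_schedule(12, attention_every=6)
    print(schedule)
    print("attention fraction:", schedule.count("attention") / len(schedule))
```

Setting `attention_every` to 4 or 6 covers the two ends of the range quoted above.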
If you found this repository helpful, please consider citing it using:
For context, we at Zyphra have built the following hybrid models:
The following datasets:
And the following engineering optimizations
We present histograms depicting the distribution of cluster sizes in all the datasets (see Fig. 7-11). Please note that all the figures are in log-log scale. We see a significant drop in the number of clusters starting from a size of around 100. This drop is present in both DCLM and FineWeb-Edu2 (see Fig. 8 and 9 respectively) and is most likely explained by a combination of the deduplication strategy and the quality filtering applied when creating both datasets: DCLM deduplication was done individually within 10 shards, while FineWeb-Edu2 was deduplicated within every Common Crawl snapshot. We find that large clusters usually contain low-quality material (repeated advertisements, license agreement templates, etc.), so it is not surprising that such documents were removed. Notably, DCLM still contained one cluster with a size close to 1 million documents, containing low-quality documents that seemingly come from advertisements (see Appendix).
We find that both Zyda-1 and Dolma-CC contain a small number of duplicates, which is expected since both datasets were deduplicated globally by their authors. The remaining duplicates are likely false negatives from the initial deduplication procedure. Note that the distributions of duplicate cluster sizes for these two datasets (Fig. 10 and 11) do not contain any sharp drops, but rather decrease hyper-exponentially with cluster size.
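A cluster-size distribution of this kind can be computed from per-document cluster labels roughly as follows. This is a sketch, not the code used to produce the figures: the input format (one cluster id per document) and the power-of-two bins, chosen because they are evenly spaced on a log axis, are assumptions.

```python
from collections import Counter
from typing import Hashable, Iterable


def cluster_size_histogram(cluster_ids: Iterable[Hashable],
                           min_size: int = 2) -> dict[int, int]:
    """Count clusters per power-of-two size bin [2^k, 2^(k+1)).

    Keys are the lower edge of each bin. Clusters smaller than `min_size`
    (documents without duplicates) are skipped.
    """
    sizes = Counter(cluster_ids).values()  # documents per cluster
    histogram: Counter = Counter()
    for size in sizes:
        if size >= min_size:
            lower_edge = 1 << (size.bit_length() - 1)
            histogram[lower_edge] += 1
    return dict(sorted(histogram.items()))


if __name__ == "__main__":
    # Toy labels: cluster "a" has 2 documents, "b" has 1, "c" has 4.
    labels = ["a", "a", "b", "c", "c", "c", "c"]
    print(cluster_size_histogram(labels))
```

Plotting the bin edges against the counts on log-log axes gives the kind of view described above, where a sharp drop shows up as a sudden fall in counts past some bin.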
Below is an example of the document from the largest cluster (~1M documents) of duplicates in DCLM (quality score 0.482627):
Is safe? Is scam?
Is safe for your PC?
Is safe or is it scam?
Domain is SafeSafe score: 1
The higher the number, the more dangerous the website.Any number higher than 1 means DANGER.
Positive votes:
Negative votes:
Vote Up Vote Down review
Have you had bad experience with Warn us, please!
Below one will find a few documents with different quality scores from DCLM coming from the same duplicates cluster. Quality score varies from ~0.2 to ~0.04.
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc) boil down to perform very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that alternative linear attention variants such as Mamba, RWKV, RetNet perform well at short context language modelling while being lacking at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, it appears that full attention is not needed every single sequence mixer but that substantially less attention can be used, which is what enables hybrids to work empirically. A similar findings have also recently been popularized applied to transformers with some recent models such as those used by CharacterAI claiming to alternate sliding-window-attention over a small window and full attention blocks. This has an equivalent effect of using cheap local sequence mixers and occasionally full attention but is less efficient than Mamba since sliding-window-attention is less efficient per FLOP than a Mamba block. The likely reason for this relates to the data distirbution. Natural language is often surprisingly predictable from primarily local correlations -- i.e. see the surprising effectiveness of pure N-gram models. However, occasionally, there is long-term information retrieval or other in-context learning required which a smaller number of attention layers can handle. In our experiments, we observe that between only 1/4 or 1/6 sequence mixer layer should be full attention, a phenomenon also reported here.
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block consisting of an attention and a MLP block is re-used multiple times throughout the network. This comprises the only attention in the network. This increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
What the success of this architecture implies is that even when attention is used rarely, there is still a great redundancy in the attention parameters -- namely that the vast majority of them are not needed. While sequencing mixing via full MHA is necessary regularly, somehow the attention block itself does not have to have separate parameters. We conjecture that this means that in fact the attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers, are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that splits the input sequence(s) to appropriate MLP experts on a per-token basis. While the MoE has the inference latency of its forward-pass parameters, all parameters need to be loaded into VRAM which often means inference can only be performed distributed across GPU clusters for large models.
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention is hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7b). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We create calculation scripts for the parameters and FLOPs of mamba models here as well as a detailed walkthrough of the calculations performed in these scripts.
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script at here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in Jax here. Why Jax when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of Jax comms for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need to utilize Zero-1 to shard optimizer states. For larger models such as Zamba-7B, we utilized tensor-parallelism (TP) for which we created our own custom implementation in both Mamba and Mamba2. We also utilized expert-parallelism (EP) for training BlackMamba.
We find, following miniCPM, that a simple curriculum training approach of increasing the proportion of higher quality tokens towards the end of training can significantly improve performance.
'High quality' is obviously subjective in part but we find that documents containing fact-rich information to be the most performant. Examples include:
In terms of the amount of annealing data, we find in general that more is better, although we are generally constrained by amount of available annealing data so that we have not been able to test truly large (>200B tokens) amounts of such data. This fits with the miniCPM findings of setting annealing to be about 10% of the total tokens of a run. We find that multiple epochs of annealing data do not appear to harm performance, yet beyond 2 epochs give little performance improvement.
Model Tech Reports Using Annealing
Papers on Annealing/LR
We performed significant ablations to explore the LR schedule. We made the following observations:
When doing annealing we find it is important to maintain a high 'replay fraction' of tokens from the original pre-training dataset to stabilize training and maintain performance. This is done both to extend the annealing phase so that the model has more optimizer steps to digest the annealing data, and to minimize forgetting of the original pre-training data distribution.
We typically find that a fraction of 50-70% 'replay' tokens from the original pre-training dataset and 50-30% tokens from the annealing datasets is optimal. Within this range, we find that the sensitivity to the exact replay fraction is quite low, yet we hold the intuition that replay should scale with the magnitude of the distribution shift between the pre-training and annealing datasets. In general, we have found annealing to be fairly robust to hyperparameter choices as long as the initial settings are sensible.
Concretely, our reccomendations for annealing are:
Dense transformer models (i.e. alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks) have dominated the DL model space for a long time. The reason for this is simple:
Lots of LLM blocks (e.g. MHA, MLPs, RWKV, Mamba, KANs, xLSTM, etc.) boil down to performing very similar modeling tasks. We at Zyphra intuit that the ingredients for a good LLM architecture are:
Typically, these components are alternated so that the sequence is mixed, the per-token representations are updated, the sequence is mixed again etc. A careful balance of sequence and token mixing is required for good performance.
Therefore, potential LLM architectures should be evaluated on whether they:
The deployment context determines which of these properties is most important, for example:
For larger models, the primary determinant of performance is scale in terms of parameters and data which reduces the importance of architectural changes except insofar as they change the scaling law coefficients. However, at smaller scales when e.g. the parameter count is fixed by hard memory limitations, architectural efficiencies which give constant improvements to performance at a given scale become important and can enable models to significantly outperform for a given inference FLOP and memory budget. This effect is also seen in training where superior architecture enables models to compete with standard transformers which are trained on significantly more tokens (requiring significantly more FLOPs) since training far past chinchilla optimal models at fixed parameter count runs into strongly sublinear scaling. Because of this, a small absolute improvement in performance due to architecture can overcome a 2-10x token budget advantage far from the chinchilla optimal point, as we observe with our Zamba1 and Zamba2 models.
Since Zyphra seeks to build personalized on-device models, this cookbook will be focused on the practical implications of architectures falling into the smaller-model regime #2. We also focus heavily on architectural innovations to maximize the loss-decrease per parameter and per inference FLOP.
The key current focus of innovation is on the sequence mixer. This is because attention is expensive at long sequence lengths while MLPs appear close to maximal efficiency. While much is still uncertain, there appears to be converging evidence that linear attention alternatives such as Mamba, RWKV, and RetNet perform well at short-context language modelling while lagging at long-context reasoning, information retrieval, and in-context learning. However, despite this slight deficit on some aspects of performance, they are significantly more FLOP and memory efficient than attention layers.
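To make the memory side of this trade-off concrete, here is a back-of-the-envelope sketch comparing the inference-time state of attention layers and Mamba-style layers. The function names and the assumed state shapes (a `d_inner x d_state` recurrent state plus a `d_inner x d_conv` convolution buffer per Mamba layer) are simplifying assumptions for illustration, not a description of any particular implementation.

```python
def kv_cache_bytes(n_attn_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    """KV cache for one sequence: keys and values grow linearly with seq_len."""
    return 2 * n_attn_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem


def mamba_state_bytes(n_mamba_layers, d_inner, d_state, d_conv=4, bytes_per_elem=2):
    """Recurrent state for one sequence: fixed size, independent of seq_len.

    Assumes each layer keeps an SSM state of d_inner x d_state elements
    and a short convolution buffer of d_inner x d_conv elements.
    """
    per_layer = d_inner * d_state + d_inner * d_conv
    return n_mamba_layers * per_layer * bytes_per_elem
```

Doubling `seq_len` doubles the KV cache, while the Mamba state does not change, which is the core reason linear sequence mixers are attractive for memory-constrained deployment.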
This motivates a hybrid architecture which mixes attention and linear sequence mixers such as Mamba. This way, the majority of the sequence mixers are more efficient than attention while just enough full attention is used to maintain performance. Empirically, full attention is not needed at every sequence mixer; substantially less attention suffices, which is what enables hybrids to work. A similar finding has recently been popularized for transformers, with some recent models such as those used by CharacterAI claiming to alternate sliding-window attention over a small window with full attention blocks. This has the equivalent effect of using cheap local sequence mixers and occasional full attention, but is less efficient than a Mamba hybrid since sliding-window attention is less efficient per FLOP than a Mamba block. The likely reason this works relates to the data distribution. Natural language is often surprisingly predictable from primarily local correlations (see, for example, the surprising effectiveness of pure N-gram models). Occasionally, however, long-term information retrieval or other in-context learning is required, which a smaller number of attention layers can handle. In our experiments, we observe that only between 1/4 and 1/6 of the sequence mixer layers need to be full attention, a phenomenon also reported here.
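A minimal sketch of such a layout, assuming attention is placed at a fixed interval (the helper name `hybrid_layout` and the regular spacing are illustrative assumptions; real models may place attention differently):

```python
def hybrid_layout(n_layers, attention_every=6):
    """Return a list of sequence-mixer types, one per layer.

    Every `attention_every`-th layer is full attention; the rest are Mamba.
    attention_every=4 and attention_every=6 bracket the 1/4 to 1/6 range.
    """
    if attention_every < 1:
        raise ValueError("attention_every must be >= 1")
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_layers)
    ]


layout = hybrid_layout(12, attention_every=6)
attention_share = layout.count("attention") / len(layout)
```

For 12 layers with `attention_every=6`, this places attention at layers 6 and 12 (1-indexed) and Mamba everywhere else.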
While several other works, such as Jamba, have explored SSM hybrid models at scale, with Zamba we have further improved the architecture on a performance-per-parameter metric. We have done this by utilizing a parameter-sharing scheme whereby a single transformer block, consisting of an attention block and an MLP block, is re-used multiple times throughout the network. This shared block comprises the only attention in the network. Sharing increases the performance of the network for a given parameter count at the expense of additional FLOPs for the multiple invocations of the shared parameters. However, given the inherent FLOP efficiency of our Mamba backbone, the end result is an architecture that outperforms transformers in both equi-token and equi-FLOP conditions.
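The parameter saving from sharing can be estimated with simple arithmetic. The sketch below assumes a textbook transformer block (four `d_model x d_model` attention projections and a non-gated MLP with a 4x expansion, ignoring biases and norms); these shapes are illustrative assumptions and not the exact Zamba block.

```python
def transformer_block_params(d_model, ffn_mult=4):
    """Approximate parameters in one attention + MLP block (no biases/norms)."""
    attn = 4 * d_model * d_model             # Q, K, V and output projections
    mlp = 2 * ffn_mult * d_model * d_model   # up and down projections
    return attn + mlp


def attention_block_params(d_model, n_invocations, shared):
    """Parameters spent on transformer blocks across the whole network.

    With sharing, one set of weights is reused at every invocation, so the
    parameter count is constant while FLOPs still scale with n_invocations.
    """
    per_block = transformer_block_params(d_model)
    return per_block if shared else per_block * n_invocations
```

With `d_model=1024` and six invocations, the shared variant stores one block's worth of weights instead of six, while performing the same number of block forward passes.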
What the success of this architecture implies is that even when attention is used rarely, there is still great redundancy in the attention parameters: the vast majority of them are not needed. While sequence mixing via full MHA is necessary regularly, somehow the attention block itself does not need separate parameters at each invocation. We conjecture that this means attention is primarily needed to 'remind' the network of the past sequence in a few stereotyped ways, and not necessarily to perform novel sequence mixing operations at every attention block. In any case, the Zamba architecture exploits this regularity to reduce the parameter count of the model for a given level of performance.
An additional change we made to the architecture, which turned out to be surprisingly important, is to concatenate the original text embeddings with the current layer embeddings at every shared attention block. We found this provided the biggest boost (other than the shared layer) to performance per parameter, while again increasing FLOPs slightly. We conjecture that by doing this, we are effectively 'reminding' the network continually of what the input tokens are while otherwise the processing in the residual stream may 'forget' them or be unable to retrieve them in a different context than they were originally processed. While in theory the residual stream itself was originally designed to ameliorate this type of forgetting, the fact that this concatenation approach works implies it is not entirely successful.
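The data flow can be sketched with a toy, single-token forward pass in plain Python. Everything here is a stand-in: `shared_block` and `mamba_block` are hypothetical placeholder functions with made-up arithmetic, and the ordering of the shared block relative to the Mamba block is an assumption. The only point is to show the original embeddings being concatenated with the current hidden state, so that the shared block sees `2 * d_model` input features.

```python
def shared_block(x_concat, d_model):
    """Placeholder for the shared attention + MLP block.

    Takes 2 * d_model features (hidden state ++ original embedding) and
    returns d_model features. The averaging is a toy stand-in.
    """
    if len(x_concat) != 2 * d_model:
        raise ValueError("shared block expects 2 * d_model input features")
    return [(x_concat[i] + x_concat[d_model + i]) / 2 for i in range(d_model)]


def mamba_block(h):
    """Placeholder for a Mamba block (toy arithmetic only)."""
    return [0.5 * v for v in h]


def forward(embeds, n_layers, shared_every):
    d_model = len(embeds)
    h = list(embeds)
    for i in range(n_layers):
        if (i + 1) % shared_every == 0:
            # Concatenate along the feature dimension with the ORIGINAL embeddings.
            h_cat = h + list(embeds)
            h = [a + b for a, b in zip(h, shared_block(h_cat, d_model))]
        # Residual update from the Mamba block.
        h = [a + b for a, b in zip(h, mamba_block(h))]
    return h
```

Note that the same `embeds` vector is re-injected at every shared-block invocation, no matter how much the residual stream `h` has drifted from it.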
Beyond this, in later Zamba2 models we also applied LoRAs to the shared layers. This allows us to further specialize the shared blocks which slightly improves performance at a very small parameter cost. Using LoRAs in this way during pretraining is unusual and we believe it is an underexplored avenue for creating extremely parameter-efficient models.
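To see why the parameter cost is small, recall that a rank-`r` LoRA on a `d_in x d_out` weight adds two low-rank factors with `r * (d_in + d_out)` parameters in total. The sketch below compares a shared projection with per-invocation LoRAs against fully separate projections; the function names and the example sizes are illustrative assumptions, not Zamba2's actual configuration.

```python
def lora_params(d_in, d_out, rank):
    """Parameters added by one LoRA: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


def shared_with_lora_params(d_in, d_out, rank, n_invocations):
    """One shared dense weight plus a separate LoRA for each invocation."""
    return d_in * d_out + n_invocations * lora_params(d_in, d_out, rank)


def unshared_params(d_in, d_out, n_invocations):
    """A separate dense weight for each invocation."""
    return d_in * d_out * n_invocations
```

For a `1024 x 1024` projection invoked six times with rank 8, the LoRAs add under 10% on top of the single shared weight, compared with a 6x cost for unshared weights.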
Let's talk about model architectures. Why do we think hybrids offer the best model quality per training/inference FLOPs?
Dense transformers are primarily composed of alternating multi-head attention (MHA) and multilayer perceptron (MLP) blocks. We believe dense transformers have the following shortcomings:
Mixture of Experts (MoE) architectures introduce a router block that sends each token of the input sequence(s) to the appropriate MLP experts. While an MoE has the inference latency of only its active (forward-pass) parameters, all parameters need to be loaded into VRAM, which often means that large models can only be served when distributed across a GPU cluster.
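The gap between what must sit in memory and what is used per token can be estimated as follows. This sketch assumes each expert is a non-gated MLP with a 4x expansion and that the router is a single `d_model x n_experts` linear layer; the function name and these shapes are illustrative assumptions.

```python
def moe_mlp_params(d_model, n_experts, top_k, ffn_mult=4):
    """Return (total, active) parameters for one MoE MLP layer.

    total:  what must be resident in memory (all experts + router).
    active: what a single token touches (top_k experts + router).
    """
    per_expert = 2 * ffn_mult * d_model * d_model
    router = d_model * n_experts
    total = n_experts * per_expert + router
    active = top_k * per_expert + router
    return total, active
```

With 8 experts and top-2 routing, roughly a quarter of the layer's parameters are active per token, yet all of them contribute to the memory footprint.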
State Space Models (SSM) offer a more efficient alternative to traditional attention mechanisms, particularly beneficial for smaller models deployed on devices with strict power and memory constraints. Models like Mamba and RWKV leverage these architectures to achieve competitive performance with significantly lower FLOP and memory requirements.
However, the exact cross-sequence dependencies of attention are hard to beat, and models without attention can require significantly more tokens to match the performance of attention-based models (Falcon Mamba 7B). Whether such attention-free models can ever fully match the performance of attention-based models on specific tasks like in-context learning and long-context reasoning is an open question.
Dense hybrid architectures combine the strengths of both dense transformers and SSMs. They don't introduce the memory overhead of MoEs, maintain the exact cross-sequence dependencies of attention, and have inference latency close to pure SSMs.
During the model planning phase, it's common to calculate what models will fit into a given budget of parameters, FLOPs, and inference/training memory. In this cookbook we present scripts we use internally to compute the parameters and FLOPs for a given model architecture and sizing. We see this as an extension of the EleutherAI cookbook but specialized to SSMs and hybrid models.
We provide calculation scripts for the parameters and FLOPs of Mamba models here, as well as a detailed walkthrough of the calculations performed in these scripts.
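For orientation, here is a minimal parameter count for a single Mamba1 block. It is a sketch that follows the layer layout of the public reference Mamba implementation as we understand it (bias-free input/output projections, a depthwise convolution with bias, and `dt_rank = ceil(d_model / 16)`); the function name and defaults are assumptions, and the scripts linked above remain the authoritative calculation.

```python
import math


def mamba1_block_params(d_model, d_state=16, d_conv=4, expand=2):
    """Approximate parameter count of one Mamba1 block (excluding its norm)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)

    in_proj = d_model * 2 * d_inner               # produces x and the gate z
    conv1d = d_inner * d_conv + d_inner           # depthwise conv weights + bias
    x_proj = d_inner * (dt_rank + 2 * d_state)    # produces dt, B and C
    dt_proj = dt_rank * d_inner + d_inner         # weights + bias
    a_log = d_inner * d_state                     # state matrix A (log-parameterized)
    d_skip = d_inner                              # skip connection D
    out_proj = d_inner * d_model

    return in_proj + conv1d + x_proj + dt_proj + a_log + d_skip + out_proj
```

The two large projections (`in_proj` and `out_proj`) dominate, so a block costs roughly `3 * expand * d_model**2` parameters, with the SSM-specific terms adding a few percent on top.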
For dense and MoE transformers, we recommend using the EleutherAI cookbook by Quentin Anthony, Hailey Schoelkopf, and Stella Biderman.
We provide a script here that tokenizes text data from a Hugging Face dataset, calculates the total number of tokens, and optionally saves the tokenized dataset.
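The core of such a script is a simple counting loop. The sketch below is not the script referenced above; it is a minimal, library-agnostic illustration in which `count_tokens` and its `tokenize` callable are hypothetical names introduced here.

```python
def count_tokens(texts, tokenize):
    """Sum the number of tokens over an iterable of strings.

    tokenize: any callable mapping a string to a sequence of tokens or ids.
    """
    total = 0
    for text in texts:
        total += len(tokenize(text))
    return total


# Toy usage with whitespace splitting standing in for a real tokenizer.
n_tokens = count_tokens(["a b c", "d e"], str.split)
```

With the Hugging Face libraries, `texts` could be a text column of a dataset loaded via `datasets.load_dataset`, and `tokenize` could be `lambda t: tokenizer(t)["input_ids"]` for a tokenizer obtained from `transformers.AutoTokenizer.from_pretrained`. Note that, by default, such tokenizers may add special tokens, which would be included in the count.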
We provide computation benchmarks for hybrid model blocks such as attention, Mamba1, and Mamba2 here. These are useful for comparing hardware performance and for efficiently sizing models.
For communication benchmarks, there are two levels of tests:
In this cookbook, we provide framework-level benchmarks in JAX here. Why JAX when our model training code is in PyTorch? Because we needed to deeply understand the communication behavior of JAX for our Tree Attention work!
We perform all our training using PyTorch within our custom internal fork of MegatronLM. For smaller models we only need to utilize Zero-1 to shard optimizer states. For larger models such as Zamba-7B, we utilized tensor-parallelism (TP) for which we created our own custom implementation in both Mamba and Mamba2. We also utilized expert-parallelism (EP) for training BlackMamba.
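To illustrate why sharding optimizer states alone goes a long way for smaller models, here is a rough per-GPU memory estimate for model states. It assumes mixed-precision Adam with the commonly used accounting of 2 bytes per parameter for half-precision weights, 2 for gradients, and 12 for optimizer states (fp32 master weights, momentum, and variance). These byte counts and the function name are assumptions for illustration, and activations and buffers are ignored.

```python
def model_state_bytes_per_gpu(n_params, dp_size, zero1=True):
    """Approximate per-GPU bytes for weights, gradients and optimizer states."""
    weights = 2 * n_params       # fp16/bf16 parameters, replicated
    grads = 2 * n_params         # fp16/bf16 gradients, replicated
    optim = 12 * n_params        # fp32 master weights + Adam momentum + variance
    if zero1:
        optim = optim / dp_size  # ZeRO-1 shards only the optimizer states
    return weights + grads + optim
```

For a 1B-parameter model on 8 data-parallel ranks, this drops model-state memory from about 16 GB to about 5.5 GB per GPU, without needing tensor parallelism.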